1 Commits

Author SHA1 Message Date
NiccoloN 87922d994f multiple-output spat computes
Validate Operations / validate-operations (push) Successful in 1h2m3s
2026-04-22 18:29:06 +02:00
179 changed files with 4165 additions and 16633 deletions
-12
View File
@@ -1,17 +1,5 @@
.zed
.idea .idea
**/.vscode **/.vscode
.claude .claude
.codex
AGENTS.md AGENTS.md
CMakeUserPresets.json
build build
build_release
cmake-build-debug
cmake-build-release
compile.sh
**/__*
-154
View File
@@ -1,159 +1,5 @@
# Raptor # Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
## Overview
PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
- a shared host memory,
- a number of cores that do most of the computation directly in their memory
(vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
support — any repeated work (e.g. convolutions) must be unrolled into
explicit per-iteration instructions.
Because of this, the amount of emitted instructions explodes quickly and the
compiler must optimize aggressively at every stage to keep compilation
tractable.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.
## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).
High-level lowering flow:
```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
```
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
operations are accelerated by storing a constant RHS matrix into a
crossbar. Crossbars cannot be re-programmed during execution, have a
limited fixed size, and there is a limited number of them per core.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
materializes PIM cores (`pim.core`), inter-core communication
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original
scheduling paper by Kwok & Ahmad,
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
standard MLIR `BufferizableOpInterface` machinery
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
- `HostConstantFolding` — folds host-side constants.
- `MaterializeHostConstantsPass` — materializes the remaining host
constants for emission.
- `VerificationPass` — checks invariants before emission.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM:
- `--maccel=PIM` — select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
## Validation
Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.
Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.
## Rebuilding
Release build (fast):
```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
```
A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
## Build ## Build
### Protobuf ### Protobuf
File diff suppressed because it is too large Load Diff
@@ -13,9 +13,8 @@ name = "pimcore"
path = "src/lib/pimcore.rs" path = "src/lib/pimcore.rs"
[features] [features]
default = [] default = ["tracing"]
tracing = [] tracing = []
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
@@ -28,9 +27,3 @@ hex = "0"
paste = "1" paste = "1"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
statrs = {version="0.16", optional=true}
comfy-table = {version="7.1", optional=true}
plotly = {version="0.8", optional=true}
rayon = "1.12.0"
faer = "0.24.0"
faer-traits = "0.24.0"
@@ -1,19 +1,14 @@
use crate::{ use crate::{
cpu::{CPU, crossbar}, cpu::{CPU, crossbar}, instruction_set::{
instruction_set::{
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith, Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
helper::add_all, helper::add_all,
}, }, memory_manager::{
memory_manager::{
MemoryStorable, MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice}, type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
}, }, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
tracing::TRACER,
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
}; };
use aligned_vec::{AVec, ConstAlign}; use aligned_vec::{AVec, ConstAlign};
use anyhow::{Context, Result, ensure}; use anyhow::{Context, Result, ensure};
use rayon::prelude::*;
use paste::paste; use paste::paste;
use std::{borrow::Cow, cell::OnceCell, collections::HashMap}; use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
@@ -81,7 +76,8 @@ pub fn functor_to_name(functor: usize) -> &'static str {
/////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions////////////////// /////////////////Scalar/register Instructions//////////////////
/////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> { pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
{
TRACER.lock().unwrap().pre_sldi(cores, data); TRACER.lock().unwrap().pre_sldi(cores, data);
let (core_indx, rd, imm) = data.get_core_rd_imm(); let (core_indx, rd, imm) = data.get_core_rd_imm();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -233,30 +229,25 @@ where
[F]: UpcastSlice<T> + UpcastSlice<M>, [F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>, [M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
// Add faer::ComplexField HERE, directly bounding M for this function only M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data); TRACER.lock().unwrap().pre_mvm::<F,M,T>(cores, data);
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup(); let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let group: usize = group.try_into().context("group can not be negative")?; let group: usize = group.try_into().context("group can not be negative")?;
let core = cores.core(core_indx); let core = cores.core(core_indx);
let r1_val = core.register(r1); let r1_val = core.register(r1);
let rd_val = core.register(rd); let rd_val = core.register(rd);
let (memory, crossbars) = core.get_memory_crossbar(); let (memory, crossbars) = core.get_memory_crossbar();
let crossbar = crossbars.get_mut(group).unwrap(); let crossbar = crossbars.get_mut(group).unwrap();
let crossbar_stored_bytes = crossbar.stored_bytes(); let crossbar_stored_bytes = crossbar.stored_bytes();
let crossbar_byte_width = crossbar.width(); let crossbar_byte_width = crossbar.width();
//Fix this
let crossbar_elem_width = crossbar_byte_width / size_of::<M>(); let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
ensure!( ensure!(
crossbar_byte_width % size_of::<M>() == 0, crossbar_byte_width & size_of::<M>() == 0,
"M not divisor of the crosbbar size" "M not divisor of the crosbbar size"
); );
let crossbar_height = crossbar.height(); let crossbar_height = crossbar.height();
let crossbar_byte_size = crossbar_byte_width * crossbar_height; let crossbar_byte_size = crossbar_byte_width * crossbar_height;
@@ -266,29 +257,19 @@ where
let load = loads[0]; let load = loads[0];
let vec: Cow<[M]> = load.up(); let vec: Cow<[M]> = load.up();
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0]; let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
let mut res = Vec::with_capacity(crossbar_elem_width);
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
partial.resize(vec.len(), M::from_f32(0.0));
// --- FAER IMPLEMENTATION --- for x in 0..crossbar_elem_width {
partial[0] = vec[0] * matrix[x];
// 1. Explicitly create a Matrix Reference (MatRef) for y in 1..crossbar_height {
let matrix_view = faer::mat::MatRef::from_row_major_slice( partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
matrix.as_ref(), }
crossbar_height,
crossbar_elem_width,
);
// 2. Explicitly create a Column Vector Reference (ColRef)
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
// 4. Convert back to standard Rust Vec
// try_as_slice() returns an Option<&[M]>.
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
// --- END FAER ---
let mut acc = add_all(partial.as_slice());
res.push(acc);
}
if relu != 0 { if relu != 0 {
res.iter_mut().for_each(|x| { res.iter_mut().for_each(|x| {
if *x < M::from_f32(0.0) { if *x < M::from_f32(0.0) {
@@ -296,16 +277,13 @@ where
} }
}); });
} }
ensure!( ensure!(
res.len() == crossbar_elem_width, res.len() == crossbar_elem_width,
"mvm generate a vector bigger thant it's requested elements" "mvm generate a vector bigger thant it's requested elements"
); );
let res_up: Cow<[T]> = res.as_slice().up(); let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref()); core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
Ok(InstructionStatus::Completed) Ok(InstructionStatus::Completed)
} }
@@ -339,7 +317,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_vvadd::<F, T>(cores, data); TRACER.lock().unwrap().pre_vvadd::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -367,7 +345,7 @@ where
); );
let res_up: Cow<[T]> = res.as_slice().up(); let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref()); core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vvadd::<F, T>(cores, data); TRACER.lock().unwrap().post_vvadd::<F,T>(cores, data);
Ok(InstructionStatus::Completed) Ok(InstructionStatus::Completed)
} }
@@ -381,7 +359,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_vvsub::<F, T>(cores, data); TRACER.lock().unwrap().pre_vvsub::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -422,7 +400,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_vvmul::<F, T>(cores, data); TRACER.lock().unwrap().pre_vvmul::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -462,7 +440,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_vvdmul::<F, T>(cores, data); TRACER.lock().unwrap().pre_vvdmul::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -498,7 +476,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_vvmax::<F, T>(cores, data); TRACER.lock().unwrap().pre_vvmax::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -547,7 +525,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
TRACER.lock().unwrap().pre_vavg::<F, T>(cores, data); TRACER.lock().unwrap().pre_vavg::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -555,10 +533,7 @@ where
let r2_val = r2; let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported"); ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd); let rd_val = core.register(rd);
ensure!( ensure!(offset_select == 1, "Offset select cannot be different from 1");
offset_select == 1,
"Offset select cannot be different from 1"
);
let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?; let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0]; let load1 = loads[0];
@@ -580,7 +555,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>, F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{ {
TRACER.lock().unwrap().pre_vrelu::<F, T>(cores, data); TRACER.lock().unwrap().pre_vrelu::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -610,7 +585,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>, F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{ {
TRACER.lock().unwrap().pre_vtanh::<F, T>(cores, data); TRACER.lock().unwrap().pre_vtanh::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -638,7 +613,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>, F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{ {
TRACER.lock().unwrap().pre_vsigm::<F, T>(cores, data); TRACER.lock().unwrap().pre_vsigm::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -658,16 +633,13 @@ pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionSta
panic!("You are calling a placeholder, the real call is the generic version"); panic!("You are calling a placeholder, the real call is the generic version");
} }
pub(super) fn vsoftmax_impl<F, T>( pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
cores: &mut CPU,
data: InstructionData,
) -> Result<InstructionStatus>
where where
[F]: UpcastSlice<T>, [F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>, F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{ {
TRACER.lock().unwrap().pre_vsoftmax::<F, T>(cores, data); TRACER.lock().unwrap().pre_vsoftmax::<F,T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx); let core = cores.core(core_indx);
@@ -684,15 +656,16 @@ where
.reduce(|a, b| if a > b { a } else { b }) .reduce(|a, b| if a > b { a } else { b })
.unwrap(); .unwrap();
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect(); let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap(); let sum = exp_values
ensure!( .iter()
sum > 0.0.into(), .copied()
"vsoftmax normalization sum must be positive" .reduce(|a, b| a + b)
); .unwrap();
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect(); let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
let res_up: Cow<[T]> = res.as_slice().up(); let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref()); core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vsoftmax::<F, T>(cores, data); TRACER.lock().unwrap().post_vsoftmax::<F,T>(cores, data);
Ok(InstructionStatus::Completed) Ok(InstructionStatus::Completed)
} }
@@ -776,10 +749,12 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
} }
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> { pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_send(cores, data);
Ok(InstructionStatus::Sending(data)) Ok(InstructionStatus::Sending(data))
} }
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> { pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_recv(cores, data);
Ok(InstructionStatus::Reciving(data)) Ok(InstructionStatus::Reciving(data))
} }
@@ -55,23 +55,15 @@ pub trait HasSigm {
impl HasSigm for f32 { impl HasSigm for f32 {
fn sigm(self) -> Self { fn sigm(self) -> Self {
if self >= 0.0 { let ex = self.exp();
1.0 / (1.0 + (-self).exp()) ex / (1.0 + ex)
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
} }
} }
impl HasSigm for f64 { impl HasSigm for f64 {
fn sigm(self) -> Self { fn sigm(self) -> Self {
if self >= 0.0 { let ex = self.exp();
1.0 / (1.0 + (-self).exp()) ex / (1.0 + ex)
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
} }
} }
@@ -169,9 +169,6 @@ impl<'a> Executable<'a> {
} }
} }
print_status(cores_instructions); print_status(cores_instructions);
#[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report();
} }
pub fn cpu(&self) -> &CPU<'a> { pub fn cpu(&self) -> &CPU<'a> {
@@ -58,20 +58,6 @@ where 'a : 'b
&& sender.internal_core == receiver.external_core && sender.internal_core == receiver.external_core
&& receiver.internal_core == sender.external_core && receiver.internal_core == sender.external_core
{ {
{
let sender = &mut core_instructions[sender.internal_core];
let pc = sender.program_counter;
let inst = sender.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_send(cpu, data);
}
{
let recv = &mut core_instructions[receiver.internal_core];
let pc = recv.program_counter;
let inst = recv.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_recv(cpu, data);
}
let [sender_core, reciver_core] = let [sender_core, reciver_core] =
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]); cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
let memory = sender_core let memory = sender_core
@@ -13,7 +13,7 @@ use crate::{
}; };
use std::io::Write; use std::io::Write;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))] #[cfg(not(feature = "tracing"))]
impl Trace { impl Trace {
/////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions////////////////// /////////////////Scalar/register Instructions//////////////////
@@ -1,32 +1,52 @@
mod tracing_isa;
mod disable; mod disable;
#[cfg(feature = "profile_time")] mod pretty_print;
mod profile; use std::{fs::File, path::{ PathBuf}};
#[cfg(feature = "profile_time")]
use profile::Trace;
#[cfg(feature = "tracing")]
mod trace;
#[cfg(feature = "tracing")]
use trace::Trace;
use crate::Executable;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
use std::path::PathBuf;
use std::sync::{LazyLock, Mutex}; use std::sync::{LazyLock, Mutex};
#[cfg(not(any(feature = "tracing", feature = "profile_time")))] use crate::Executable;
pub struct Trace {}
#[cfg(not(any(feature = "tracing", feature = "profile_time")))] #[cfg(feature = "tracing")]
impl Trace { pub struct Trace {
fn new() -> Self { out_files : Vec<File>
Self {}
}
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
} }
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into()); #[cfg(feature = "tracing")]
impl Trace {
fn new() -> Self {
Self { out_files : Vec::new()}
}
pub fn init(&mut self, num_core : usize , mut path : PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
#[cfg(not(feature = "tracing"))]
pub struct Trace {
}
#[cfg(not(feature = "tracing"))]
impl Trace {
fn new() -> Self {
Self { }
}
pub fn init(&mut self, num_core : usize, path : PathBuf ) {
}
}
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| { Trace::new().into()});
@@ -1,73 +0,0 @@
use std::{collections::HashMap, path::PathBuf, time::Instant};
use crate::tracing::profile::profile_analysis::{
analyze_timings, generate_interactive_report, print_textual_report,
};
pub mod profile_analysis;
pub mod profile_isa;
pub struct Trace {
instruction_times: HashMap<String, Vec<(u128,u128)>>,
core_start_time: HashMap<usize, Option<Instant>>,
start_time: Instant,
}
impl Trace {
pub fn new() -> Self {
let mut instruction_times = HashMap::new();
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
Self {
instruction_times,
core_start_time: HashMap::new(),
start_time: Instant::now()
}
}
pub fn init(&mut self, num_core: usize, path: PathBuf) {
for i in 0..num_core {
self.core_start_time.insert(i, None);
}
}
pub fn report(&self) {
let res = analyze_timings(&self.instruction_times);
print_textual_report(&res);
generate_interactive_report(
&self.instruction_times,
&["mvmul", "recv"],
"/tmp/report.html",
);
}
}
@@ -1,192 +0,0 @@
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
use std::collections::HashMap;
#[derive(Debug)]
pub struct InstructionStats {
pub name: String,
pub count: usize,
pub total_time: u128,
pub min: f64,
pub max: f64,
pub mean: f64,
pub median: f64,
pub std_dev: f64,
pub cv: f64,
pub p95: f64,
pub p99: f64,
pub skewness: f64,
pub kurtosis: f64,
}
fn format_time(ns: f64) -> String {
if ns.is_nan() {
return "NaN".to_string();
}
if ns >= 1_000_000_000.0 {
format!("{:.2} s", ns / 1_000_000_000.0)
} else if ns >= 1_000_000.0 {
format!("{:.2} ms", ns / 1_000_000.0)
} else if ns >= 1_000.0 {
format!("{:.2} µs", ns / 1_000.0)
} else {
format!("{:.2} ns", ns)
}
}
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
let n = times.len() as f64;
if n < 4.0 || std_dev == 0.0 {
return (f64::NAN, f64::NAN);
}
let mut sum_m3 = 0.0;
let mut sum_m4 = 0.0;
for &x in times {
let deviation = x - mean;
sum_m3 += deviation.powi(3);
sum_m4 += deviation.powi(4);
}
let m3 = sum_m3 / n;
let m4 = sum_m4 / n;
let skewness = m3 / std_dev.powi(3);
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
(skewness, kurtosis)
}
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
let mut results = Vec::new();
for (instruction, times) in timings {
let count = times.len();
if count == 0 {
continue;
}
// Extract ONLY the duration (the second element of the tuple) for stats
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
let total_time: u128 = durations.iter().sum();
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
let mut data = Data::new(f64_times.clone());
let mean = data.mean().unwrap_or(0.0);
let std_dev = data.std_dev().unwrap_or(0.0);
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
results.push(InstructionStats {
name: instruction.clone(),
count,
total_time,
min: data.min(),
max: data.max(),
mean,
median: data.median(),
std_dev,
cv,
p95: data.percentile(95),
p99: data.percentile(99),
skewness,
kurtosis,
});
}
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
results
}
pub fn print_textual_report(stats: &[InstructionStats]) {
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_header(vec![
"Instruction",
"Count",
"Total Time",
"Mean",
"Median",
"Min",
"Max",
"P95",
"P99",
"StdDev",
"CV",
"Skewness",
"Kurtosis",
]);
for stat in stats {
table.add_row(vec![
Cell::new(&stat.name),
Cell::new(stat.count.to_string()),
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
Cell::new(format_time(stat.mean)),
Cell::new(format_time(stat.median)),
Cell::new(format_time(stat.min)),
Cell::new(format_time(stat.max)),
Cell::new(format_time(stat.p95)),
Cell::new(format_time(stat.p99)),
Cell::new(format_time(stat.std_dev)),
Cell::new(format!("{:.3}", stat.cv)),
Cell::new(format!("{:.2}", stat.skewness)),
Cell::new(format!("{:.2}", stat.kurtosis)),
]);
}
println!("{table}");
}
pub fn generate_interactive_report(
timings: &HashMap<String, Vec<(u128, u128)>>,
instructions_to_plot: &[&str], // <-- NEW: Only plot these
file_path: &str,
) {
use plotly::common::{Mode, Marker, Line};
use plotly::layout::{Axis, Layout};
use plotly::{Plot, Scatter};
use std::collections::HashMap;
let mut plot = Plot::new();
for &instruction_name in instructions_to_plot {
// Only proceed if the instruction exists in our timings map
if let Some(times) = timings.get(instruction_name) {
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
let text_array: Vec<String> = times.iter()
.map(|&(_, dur)| format_time(dur as f64))
.collect();
let trace = Scatter::new(x_axis, y_axis)
.name(instruction_name)
.mode(Mode::LinesMarkers)
.marker(Marker::new().size(4).opacity(0.6))
.line(Line::new().width(1.0))
.text_array(text_array)
.hover_info(plotly::common::HoverInfo::All);
plot.add_trace(trace);
}
}
let layout = Layout::new()
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
plot.set_layout(layout);
plot.write_html(file_path);
println!("🌐 Interactive timeline saved to {}", file_path);
}
@@ -1,364 +0,0 @@
use crate::{
cpu::CPU,
instruction_set::instruction_data::InstructionData,
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
},
tracing::Trace,
utility::{add_offset_r1, add_offset_rd},
};
use std::io::Write;
use std::time::Instant;
#[cfg(feature = "profile_time")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
if self.core_start_time.get(&core_indx).unwrap().is_none() {
self.core_start_time.insert(core_indx, Some(Instant::now()));
}
}
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
let Self {
instruction_times,
core_start_time,
start_time,
} = self;
let now = Instant::now();
instruction_times
.get_mut(name)
.unwrap()
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
self.core_start_time.insert(core_indx, None);
}
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sldi");
}
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sld");
}
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sadd");
}
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ssub");
}
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smul");
}
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "saddi");
}
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smuli");
}
/////////////////////////////////////////////////////////////////
///////////////////Matrix/vector Instructions////////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "setbw");
}
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "mvmul");
}
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvadd");
}
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvsub");
}
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmul");
}
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvdmul");
}
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmax");
}
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vavg");
}
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vrelu");
}
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vtanh");
}
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsigm");
}
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsoftmax");
}
/////////////////////////////////////////////////////////////////
/////Communication/synchronization Instructions/////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ld");
}
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "st");
}
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lldi");
}
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lmv");
}
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "send");
}
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "recv");
}
}
@@ -1,28 +0,0 @@
use std::{fs::File, path::PathBuf};
pub mod pretty_print;
pub mod tracing_isa;
pub struct Trace {
out_files: Vec<File>,
}
impl Trace {
pub fn new() -> Self {
Self {
out_files: Vec::new(),
}
}
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
@@ -1,4 +1,4 @@
use crate::{tracing::trace::pretty_print, utility::add_offset_r2}; use crate::tracing::pretty_print;
use std::fs::File; use std::fs::File;
use crate::{ use crate::{
@@ -13,6 +13,7 @@ use crate::{
}; };
use std::io::Write; use std::io::Write;
#[cfg(feature = "tracing")]
impl Trace { impl Trace {
/////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions////////////////// /////////////////Scalar/register Instructions//////////////////
@@ -283,6 +284,7 @@ impl Trace {
M: UpcastDestTraits<M> + MemoryStorable + FromFloat, M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
use crate::tracing::pretty_print;
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup(); let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let file: &mut File = self let file: &mut File = self
@@ -356,6 +358,8 @@ impl Trace {
T: UpcastDestTraits<T> + MemoryStorable, T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable, F: UpcastDestTraits<F> + MemoryStorable,
{ {
use crate::{tracing::pretty_print, utility::add_offset_r2};
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) = let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self let file: &mut File = self
@@ -986,6 +990,8 @@ impl Trace {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) { pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) = let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self let file: &mut File = self
@@ -1038,6 +1044,8 @@ impl Trace {
} }
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) { pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) = let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self let file: &mut File = self
@@ -1130,6 +1138,7 @@ impl Trace {
} }
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) { fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) = let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset(); data.get_core_rd_r1_r2_immlen_offset();
+1 -8
View File
@@ -1,12 +1,5 @@
add_pim_library(OMPimCommon add_pim_library(OMPimCommon
IR/AddressAnalysis.cpp PimCommon.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
IR/WeightUtils.cpp
Support/DebugDump.cpp
Support/Diagnostics.cpp
Support/FileSystemUtils.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
-259
View File
@@ -1,259 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
}
namespace {
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (mlir::isa<mlir::BlockArgument>(value))
return value;
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs + *rhs;
}
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs - *rhs;
}
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs * *rhs;
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return mlir::failure();
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!integerAttr)
return mlir::failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return mlir::failure();
offsets.push_back(*resolvedOffset);
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return mlir::failure();
sizes.push_back(*resolvedSize);
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return mlir::failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
} // namespace onnx_mlir
-43
View File
@@ -1,43 +0,0 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
/// byte offset after peeling aliases, casts, and contiguous subviews.
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
/// Records compile-time facts used when interpreting address arithmetic and
/// loop-carried aliases inside PIM regions.
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
/// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
} // namespace onnx_mlir
-745
View File
@@ -1,745 +0,0 @@
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
namespace compact_asm {
using namespace mlir;
enum class ListDelimiter {
Square,
Paren
};
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
}
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
}
template <typename StreamT>
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
}
template <typename StreamT>
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
}
template <typename EntryT, typename ParseEntryFn>
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0) {
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
}
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
return parseCompressedIntegerEntries(parser, delimiter, values);
}
template <typename RangeT, typename PrintEntryFn>
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
Kind kind = Kind::Single;
size_t covered = 1;
size_t repeatCount = 1;
size_t progressionValueCount = 1;
int64_t step = 1;
IntT firstValue {};
IntT lastValue {};
};
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
return compression;
};
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
for (size_t index = 0; index < values.size();) {
if (index != 0)
stream << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printOpenDelimiter(stream, ListDelimiter::Paren);
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
printCloseDelimiter(stream, ListDelimiter::Paren);
stream << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
stream << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
stream << " by " << flat.step;
if (flat.repeatCount > 1)
stream << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
stream << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
stream << flat.firstValue;
index += flat.covered;
break;
}
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
printOpenDelimiter(stream, delimiter);
printCompressedIntegerEntries(stream, values);
printCloseDelimiter(stream, delimiter);
}
template <typename IntT>
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
}
template <typename IntT>
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
}
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedValueSequence(printer, values);
printCloseDelimiter(printer, delimiter);
}
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedTypeSequence(printer, types);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
return success();
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
return success();
}
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
return failure();
return parseOptionalCloseDelimiter(parser, delimiter);
}
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
return false;
SmallVector<Value> valueVec(values.begin(), values.end());
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
return false;
return true;
}
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
return false;
SmallVector<Type> typeVec(types.begin(), types.end());
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
return false;
return true;
}
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printOperand(values[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (values.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printType(types[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (types.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleOperands.clear();
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<Type> tupleTypes;
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleTypes.clear();
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
Type type;
if (parser.parseType(type))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
types.push_back(type);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
if (block.getNumArguments() == 0) {
printer << "() = ()";
return;
}
if (block.getNumArguments() == 1) {
printer.printOperand(block.getArgument(0));
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
return;
}
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
OpAsmParser::Argument firstArgument,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::Argument lastArgument;
if (parser.parseArgument(lastArgument))
return failure();
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
}
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
OpAsmParser::Argument argument;
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
arguments.push_back(argument);
}
return success();
}
arguments.push_back(firstArgument);
return success();
}
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument))
return failure();
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
}
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
argument.type = inputType;
}
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalRParen())) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
return success();
}
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
if (parser.parseRParen() || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
return success();
}
OpAsmParser::Argument argument;
if (parser.parseArgument(argument) || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
arguments.push_back(argument);
return success();
}
} // namespace compact_asm
} // namespace onnx_mlir
#endif
-67
View File
@@ -1,67 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
}
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
-24
View File
@@ -1,24 +0,0 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
/// Returns true for ops in a `pim.core` body that only participate in static
/// address or index computation and therefore do not emit PIM instructions.
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
/// their bounds are known and invoking `callback` only on instruction-emitting
/// operations.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir
-45
View File
@@ -1,45 +0,0 @@
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
if (!moduleOp)
return mlir::failure();
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return mlir::failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return mlir::failure();
}
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return mlir::failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
return mainGraphFunc;
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return mlir::failure();
}
} // namespace onnx_mlir
-13
View File
@@ -1,13 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
/// Resolves the function the PIM pipeline should treat as its entry point.
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
/// non-external function if the module is otherwise unambiguous.
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
-89
View File
@@ -1,89 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
llvm::SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
llvm::SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
} // namespace onnx_mlir
-22
View File
@@ -1,22 +0,0 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir
-108
View File
@@ -1,108 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(mlir::Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
mlir::Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
llvm::SmallPtrSet<mlir::Value, 8> visited;
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
if (!visited.insert(currentValue).second)
return true;
if (currentValue.use_empty())
return false;
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
if (isSpatialMvmVmmWeightUse(use))
return true;
mlir::Operation* user = use.getOwner();
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
return false;
});
};
return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
});
}
} // namespace onnx_mlir
-29
View File
@@ -1,29 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/StringRef.h"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable
/// weight across later PIM lowering/codegen stages.
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
/// allowing later passes to preserve it as a dedicated weight-like object.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// Visits weight operands consumed by Pim core ops/core batches so downstream
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
} // namespace onnx_mlir
+546
View File
@@ -0,0 +1,546 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
if (auto result = dyn_cast<OpResult>(value))
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs + *rhs;
}
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs - *rhs;
}
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs * *rhs;
}
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return failure();
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
if (!integerAttr)
return failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return failure();
// Trace the loop carry back to its underlying memref, then if that memref is the
// loop's own iter-arg we know the base comes from the corresponding init arg
// (every iteration yields the same backing memory in the DPS sense).
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return failure();
offsets.push_back(*resolvedOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return failure();
sizes.push_back(*resolvedSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return failure();
}
}
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
bool isCoreStaticAddressOp(Operation* op) {
return isa<arith::ConstantOp,
arith::AddIOp,
arith::SubIOp,
arith::MulIOp,
arith::DivUIOp,
arith::RemUIOp,
arith::IndexCastOp,
memref::AllocOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp>(op);
}
LogicalResult walkPimCoreBlock(Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (Operation& op : block) {
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
+70 -10
View File
@@ -7,22 +7,82 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir { namespace onnx_mlir {
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId"; struct ResolvedContiguousAddress {
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds"; mlir::Value base;
int64_t byteOffset = 0;
};
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir } // namespace onnx_mlir
-27
View File
@@ -1,27 +0,0 @@
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
namespace onnx_mlir {
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
moduleOp.print(os, flags);
os.flush();
file.close();
}
} // namespace onnx_mlir
-13
View File
@@ -1,13 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
#include <string>
namespace onnx_mlir {
/// Emits a MLIR snapshot under the current compiler output
/// directory for pass-level debugging.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
} // namespace onnx_mlir
-41
View File
@@ -1,41 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
return op->emitOpError() << "requires statically shaped " << valueDescription;
}
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks) {
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
if (supportedRanks.empty())
return diag;
diag << "; supported rank";
if (supportedRanks.size() != 1)
diag << 's';
diag << ' ';
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
return diag;
}
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
}
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
llvm::StringRef action,
llvm::StringRef path,
const std::error_code& errorCode) {
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
return mlir::failure();
}
} // namespace onnx_mlir::pim
-38
View File
@@ -1,38 +0,0 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <system_error>
namespace onnx_mlir::pim {
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
/// accepted by the current lowering/codegen path.
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks);
/// Emits a consistent diagnostic for missing symbol/global references.
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
/// the relevant IR location.
mlir::LogicalResult
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
template <typename T>
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
return mlir::success(succeeded(value));
}
} // namespace onnx_mlir::pim
@@ -1,24 +0,0 @@
#include <filesystem>
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
} // namespace onnx_mlir
@@ -1,13 +0,0 @@
#pragma once
#include <string>
namespace onnx_mlir {
/// Returns the directory that should hold PIM artifacts/debug dumps for the
/// current compiler invocation.
std::string getOutputDir();
void createDirectory(const std::string& directory);
} // namespace onnx_mlir
-3
View File
@@ -15,10 +15,7 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp PimCodeGen.cpp
PimWeightEmitter.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
-123
View File
@@ -1,123 +0,0 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
std::error_code errorCode;
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
hostFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
if (denseAttr.isSplat()) {
size_t elementSize = rawData.size();
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
}
else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
std::memcpy(dst, rawData.data(), rawData.size());
}
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
configJson["core_cnt"] = maxCoreId + 1;
configJson["adc_count"] = 16;
configJson["cell_precision"] = 2;
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
} // namespace onnx_mlir
-26
View File
@@ -1,26 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
namespace onnx_mlir {
class PimAcceleratorMemory;
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
llvm::json::Object xbarsPerArrayGroup,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
-126
View File
@@ -1,126 +0,0 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op)) {
pim::PimHaltOp::create(builder, op.getLoc());
continue;
}
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
builder,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
builder,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
hostSource = memcpBatchOp.getHostSource();
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
hostSource,
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
return callback(scalarCore);
}
} // namespace onnx_mlir
-13
View File
@@ -1,13 +0,0 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+339 -396
View File
@@ -1,48 +1,30 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <absl/types/compare.h>
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cstdint> #include <cmath>
#include <fstream>
#include <string> #include <string>
#include <utility> #include <utility>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Common/IR/CompactAsmUtils.hpp" #include "Conversion/ONNXToSpatial/Common.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
@@ -71,22 +53,9 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants; SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases; SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
SmallVector<mlir::Value> args;
for (mlir::Value arg : funcOp.getArguments()) {
gatherMemEntry(arg);
args.push_back(arg);
}
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!hasWeightAlways(getGlobalOp)) { if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (globalMemrefOp.getName().starts_with("arg")) {
StringRef indexStr = globalMemrefOp.getName().substr(4);
int index = 0;
llvm::to_integer(indexStr, index, 10);
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
}
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult()); auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
if (inserted) if (inserted)
gatherMemEntry(getGlobalOp.getResult()); gatherMemEntry(getGlobalOp.getResult());
@@ -95,6 +64,9 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
} }
}); });
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
funcOp.walk([&](memref::AllocOp allocOp) { funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>()) if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult()); gatherMemEntry(allocOp.getResult());
@@ -112,60 +84,6 @@ void PimMemory::allocateCore(Operation* op) {
allocateGatheredMemory(); allocateGatheredMemory();
} }
std::string formatMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
double size = static_cast<double>(bytes);
while (size >= 1024 && i < 6) {
size /= 1024;
i++;
}
// Formats to 2 decimal places
std::string out;
llvm::raw_string_ostream rss(out);
rss << llvm::format("%.2f ", size) << units[i];
return rss.str();
}
static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
os << "\tNumber of allocas: " << row.numAlloca << "\n";
os << "\tAllocated memory: " << formatMemory(row.sizeAlloca) << "\n";
os << "\tNumber of globals: " << row.numGlobal << "\n";
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
}
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
MemoryReportRow result = lhs;
result.numAlloca += rhs.numAlloca;
result.sizeAlloca += rhs.sizeAlloca;
result.numGlobal += rhs.numGlobal;
result.sizeGlobal += rhs.sizeGlobal;
return result;
}
MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row;
for (auto& [val, memEntry] : globalMemEntriesMap) {
if (auto op = val.getDefiningOp()) {
if (isa<memref::AllocOp>(op)) {
row.numAlloca++;
row.sizeAlloca += memEntry.size;
}
if (isa<memref::GetGlobalOp>(op)) {
row.numGlobal++;
row.sizeGlobal += memEntry.size;
}
}
}
return row;
}
void PimMemory::remove(mlir::Value val) {
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
globalMemEntriesMap.erase(removeIter);
}
MemEntry PimMemory::getMemEntry(mlir::Value value) const { MemEntry PimMemory::getMemEntry(mlir::Value value) const {
auto iter = globalMemEntriesMap.find(value); auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end()); assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
@@ -206,99 +124,6 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
return iter->second.address + resolvedAddress->byteOffset; return iter->second.address + resolvedAddress->byteOffset;
} }
void PimAcceleratorMemory::reportHost() {
hostReportRow = hostMem.getReportRow();
}
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
MemoryReportEntry entry;
entry.kind = MemoryReportEntry::Kind::Batch;
entry.id = batchId;
llvm::append_range(entry.coreIds, coreIds);
entry.row = row;
reportEntries.push_back(std::move(entry));
}
void PimAcceleratorMemory::flushReport() {
if (!fileReport.is_open())
return;
llvm::raw_os_ostream os(fileReport);
if (hostReportRow.has_value()) {
os << "Host:\n";
printMemoryReportRow(os, *hostReportRow);
}
if (!reportEntries.empty()) {
if (hostReportRow.has_value())
os << "\n";
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
if (lhs.kind != rhs.kind)
return lhs.kind == MemoryReportEntry::Kind::Batch;
const MemoryReportRow& lhsRow = lhs.row;
const MemoryReportRow& rhsRow = rhs.row;
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
if (lhsRow.numAlloca != rhsRow.numAlloca)
return lhsRow.numAlloca > rhsRow.numAlloca;
if (lhsRow.sizeGlobal != rhsRow.sizeGlobal)
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
if (lhsRow.numGlobal != rhsRow.numGlobal)
return lhsRow.numGlobal > rhsRow.numGlobal;
return lhs.id < rhs.id;
});
for (size_t index = 0; index < reportEntries.size();) {
size_t runEnd = index + 1;
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
&& reportEntries[runEnd].row == reportEntries[index].row) {
++runEnd;
}
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
os << "Batch ";
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
if (batchIndex != index)
os << ",\n ";
os << reportEntries[batchIndex].id << " (cores ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
os << ")";
}
}
else {
llvm::SmallVector<int32_t, 8> coreIds;
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n";
printMemoryReportRow(os, reportEntries[index].row);
if (runEnd < reportEntries.size())
os << "\n";
index = runEnd;
}
}
os.flush();
fileReport.close();
}
void PimAcceleratorMemory::clean(mlir::Operation* op) {
for (auto value : op->getResults()) {
hostMem.remove(value);
for (auto& device : deviceMem)
device.second.remove(value);
}
}
json::Object PimCodeGen::createEmptyOffset() { json::Object PimCodeGen::createEmptyOffset() {
json::Object offset; json::Object offset;
offset["offset_select"] = 0; offset["offset_select"] = 0;
@@ -306,12 +131,6 @@ json::Object PimCodeGen::createEmptyOffset() {
return offset; return offset;
} }
size_t PimCodeGen::remapCoreId(size_t coreId) const {
auto it = emittedCoreIds.find(coreId);
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
return it->second;
}
static json::Object createRs1OnlyOffset() { static json::Object createRs1OnlyOffset() {
json::Object offset; json::Object offset;
offset["offset_select"] = 1; offset["offset_select"] = 1;
@@ -371,7 +190,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
json::Object json; json::Object json;
json["op"] = opName; json["op"] = opName;
json["rd"] = 0; json["rd"] = 0;
json["core"] = remapCoreId(coreId); json["core"] = coreId;
json["size"] = size; json["size"] = size;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
emitInstruction(std::move(json)); emitInstruction(std::move(json));
@@ -423,62 +242,10 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize()); "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
} }
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize()); emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
} }
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
int64_t axis = concatOp.getAxis();
ArrayRef<int64_t> outputShape = outputType.getShape();
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
size_t outerCount = 1;
for (int64_t dim = 0; dim < axis; ++dim)
outerCount *= static_cast<size_t>(outputShape[dim]);
size_t innerCount = 1;
for (size_t dim = static_cast<size_t>(axis) + 1; dim < outputShape.size(); ++dim)
innerCount *= static_cast<size_t>(outputShape[dim]);
size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
size_t concatOffset = 0;
for (mlir::Value input : concatOp.getInputs()) {
auto inputType = cast<ShapedType>(input.getType());
assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");
size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
size_t inputAddr = addressOf(input, knowledge);
for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
}
concatOffset += inputConcatDim;
}
}
template <typename MVMTy> template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp, MVMTy mvmLikeOp,
@@ -489,6 +256,11 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
} }
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge); auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge); auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
@@ -640,8 +412,6 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge); auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
@@ -704,59 +474,67 @@ std::string getMemorySizeAsString(size_t size) {
return std::to_string(size) + " Bytes"; return std::to_string(size) + " Bytes";
} }
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) { static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
SmallVector<unsigned, 8> indices; SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) { auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex)) if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex); indices.push_back(weightIndex);
}; };
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices); llvm::sort(indices);
return indices; return indices;
} }
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) { /// Write global constant data into a binary memory image at their allocated addresses.
return getUsedWeightIndices(coreOp.getBody().front()); static OnnxMlirCompilerErrorCodes
} writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) { SmallPtrSet<Operation*, 16> writtenGlobals;
SmallVector<Operation*> coreLikeOps; funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
for (Operation& op : funcOp.getBody().front()) if (hasWeightAlways(getGlobalOp))
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op)) return;
coreLikeOps.push_back(&op); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
return coreLikeOps; if (!globalOp)
} return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
static void aliasMaterializedHostGlobals(ModuleOp moduleOp, return;
func::FuncOp funcOp, auto initialValue = globalOp.getInitialValue();
pim::PimCoreOp coreOp, if (!initialValue)
PimAcceleratorMemory& memory) { return;
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) { auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult())) if (!denseAttr)
return; return;
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
if (!targetGlobal) ArrayRef<char> rawData = denseAttr.getRawData();
return; char* dst = memoryBuffer.data() + memEntry.address;
mlir::Value aliasedValue; if (denseAttr.isSplat()) {
funcOp.walk([&](memref::GetGlobalOp candidate) { size_t elementSize = rawData.size();
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult())) assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
return; for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal) std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
aliasedValue = candidate.getResult(); }
}); else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
if (aliasedValue) std::memcpy(dst, rawData.data(), rawData.size());
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue]; }
}); });
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
} }
/// Dispatch all operations in a core region to the appropriate code generator. /// Dispatch all operations in a core region to the appropriate code generator.
@@ -775,16 +553,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLmvOp(lmvOp, knowledge); coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op)) else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge); coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op)) else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge); coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge); coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op)) else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -807,10 +581,9 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op)) else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else { else {
op.emitError("Unsupported codegen for this operation"); op.emitError("Unsupported codegen for this operation");
op.dump();
return failure(); return failure();
} }
processedOperations++; processedOperations++;
@@ -819,6 +592,225 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
return failed(result) ? -1 : static_cast<int64_t>(processedOperations); return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
} }
/// Write crossbar weight matrices as padded binary files for a single core.
static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
pim::PimCoreOp coreOp,
StringRef coreWeightsDirPath,
json::Array& xbarsPerGroup) {
int64_t xbarSize = crossbarSize.getValue();
std::error_code errorCode;
size_t weightIndex = 0;
for (auto weight : coreOp.getWeights()) {
xbarsPerGroup.push_back(weightIndex);
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin").str();
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
weightIndex++;
}
return CompilerSuccess;
}
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(!getGlobalOp && "Weight is not from a memref.get_global");
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
assert(!globalOp && "Could not find memref.global");
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
assert(!initialValue && "memref.global has no initial value");
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
assert(!denseAttr && "memref.global initial value is not dense");
}
if (mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
mapCoreWeightToFileName[coreOp].insert(weightToFile);
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
}
}
return mapCoreWeightToFileName;
}
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t coreCount,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
// +1 because pimsim-nn also considers the host as a core
configJson["core_cnt"] = coreCount + 1;
// TODO: Should this be based on the floating point type used in the model?
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
// Number of ADC for MVM units
configJson["adc_count"] = 16;
// The bit precision of each ADC
configJson["cell_precision"] = 2;
// Crossbar configuration
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
// Memory layout of inputs and outputs
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) { OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
if (!outputDirPath.empty()) { if (!outputDirPath.empty()) {
if (auto error = sys::fs::create_directory(outputDirPath)) { if (auto error = sys::fs::create_directory(outputDirPath)) {
@@ -834,134 +826,85 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimAcceleratorMemory memory; PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp); memory.hostMem.allocateHost(moduleOp, funcOp);
memory.reportHost();
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath)) if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
return err; return err;
if (auto err = writeHostCoreJson(outputDirPath)) // Write empty host core file
return err; std::error_code errorCode;
auto outputHostCorePath = outputDirPath + "/core_0.json";
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// The host core json contains 2 random instructions, just to make pimsim-nn happy
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
hostFileStream.close();
// For each core, specify the number of crossbar per array group. // For each core, specify the number of crossbar per array group.
// This implementation always assigns one crossbar per group. // This implementation always assigns one crossbar per group.
json::Object xbarsPerArrayGroup; json::Object xbarsPerArrayGroup;
size_t maxCoreId = 0; size_t coreCount = 0;
uint64_t nextBatchReportId = 0;
// Create Weight Folder // Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp); for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
llvm::DenseMap<size_t, size_t> emittedCoreIds; auto coreId = coreOp.getCoreId();
size_t nextEmittedCoreId = 1; coreCount++;
for (Operation* op : coreLikeOps) { std::error_code errorCode;
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) { auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId()); raw_fd_ostream coreFileStream(outputCorePath, errorCode);
if (!emittedCoreIds.contains(originalCoreId)) if (errorCode) {
emittedCoreIds[originalCoreId] = nextEmittedCoreId++; errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
continue; return InvalidOutputFileAccess;
}
coreFileStream << '[';
PimCodeGen coreCodeGen(memory, coreFileStream);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);
// Remove trailing comma, close JSON array
coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']';
coreFileStream.close();
// Write crossbar weights for this core
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess;
} }
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op); auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
auto batchCoreIds = getBatchCoreIds(coreBatchOp); json::Array xbarsPerGroup;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) { for (unsigned index : getUsedWeightIndices(coreOp)) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]); if (index >= coreOp.getWeights().size()) {
if (!emittedCoreIds.contains(originalCoreId)) coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
emittedCoreIds[originalCoreId] = nextEmittedCoreId++; assert(index < coreOp.getWeights().size() && "Weight index is out of range");
} }
} mlir::Value weight = coreOp.getWeights()[index];
xbarsPerGroup.push_back(index);
for (Operation* op : coreLikeOps) { assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes { auto& fileName = mapWeightToFile[weight];
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId()); if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
size_t coreId = emittedCoreIds.lookup(originalCoreId); coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
maxCoreId = std::max(maxCoreId, coreId); errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
std::error_code errorCode; << '\n';
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
coreFileStream << '[';
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);
coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']';
coreFileStream.close();
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess;
}
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
json::Array xbarsPerGroup;
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight];
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")
<< "\nError:" << error.message() << '\n';
return InvalidOutputFileAccess;
}
}
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
if (temporaryCore)
coreOp.walk([&memory](Operation* op) { memory.clean(op); });
return CompilerSuccess;
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
if (auto err = emitCore(coreOp, false))
return err;
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
.getReportRow());
continue;
} }
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op); xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
SmallVector<int32_t> reportedCoreIds;
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
laneResult = emitCore(coreOp, true);
if (laneResult == CompilerSuccess)
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
return laneResult == CompilerSuccess ? success() : failure();
})))
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
}
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
} }
memory.flushReport(); return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath);
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
} }
+5 -62
View File
@@ -1,14 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/Operation.h"
#include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -21,34 +14,12 @@ struct MemEntry {
size_t size; size_t size;
}; };
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
uint64_t numGlobal = 0;
uint64_t sizeGlobal = 0;
bool operator==(const MemoryReportRow& other) const {
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
&& sizeGlobal == other.sizeGlobal;
}
};
struct MemoryReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
};
class PimMemory { class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries; llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4; size_t minAlignment = 4;
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
@@ -62,8 +33,6 @@ public:
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op); void allocateCore(mlir::Operation* op);
MemoryReportRow getReportRow() const;
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; } size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const; MemEntry getMemEntry(mlir::Value value) const;
@@ -76,43 +45,23 @@ public:
private: private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem; llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
public: public:
PimAcceleratorMemory() PimAcceleratorMemory()
: hostMem(memEntriesMap) { : hostMem(memEntriesMap) {}
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/reports/";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/memory_report.txt", std::ios::out);
fileReport = std::move(file);
}
PimMemory& getOrCreateDeviceMem(size_t id); PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
void flushReport();
void clean(mlir::Operation* op);
}; };
class PimCodeGen { class PimCodeGen {
PimAcceleratorMemory& memory; PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream; llvm::raw_fd_ostream& coreFileStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge); return memory.getValueAddress(value, knowledge);
} }
size_t remapCoreId(size_t coreId) const;
static llvm::json::Object createEmptyOffset(); static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const; void emitInstruction(llvm::json::Object instruction) const;
@@ -134,20 +83,15 @@ class PimCodeGen {
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const; void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public: public:
PimCodeGen(PimAcceleratorMemory& memory, PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
llvm::raw_fd_ostream& coreJson, : memory(memory), coreFileStream(coreJson) {}
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const; void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const; void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy> template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge); void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
@@ -162,7 +106,6 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
}; };
+15 -2
View File
@@ -1,3 +1,16 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===------------------------- PimCompilerOptions.cpp --------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Compiler Options for PIM
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions" #define DEBUG_TYPE "PimCompilerOptions"
@@ -28,7 +41,7 @@ llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256)); crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
llvm::cl::opt<long> coresCount("core-count", llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."), llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
@@ -38,7 +51,7 @@ llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size", "dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. " llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."), "Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(4000)); llvm::cl::init(1024));
llvm::cl::opt<bool> llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error", ignoreConcatError("ignore-concat-error",
-220
View File
@@ -1,220 +0,0 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
strides[index] = strides[index + 1] * shape[index + 1];
return strides;
}
bool allStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!allStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStridesForShape(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
}
return mapCoreWeightToFileName;
}
} // namespace onnx_mlir
-16
View File
@@ -1,16 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include <string>
namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -3,11 +3,6 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen) add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
HostFoldability.cpp
HostLegality.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns/Math/Conv.cpp Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp Patterns/Math/Gemm.cpp
@@ -23,9 +18,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
Common/ComputeRegionBuilder.cpp Common.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -1,12 +1,24 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "ShapeTilingUtils.hpp" #include <cassert>
#include <optional>
#include <utility>
#include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -32,29 +44,10 @@ SmallVector<Value> sliceTensor(
for (int64_t i = 0; i < numSlices; i++) { for (int64_t i = 0; i < numSlices; i++) {
offsets[axis] = rewriter.getIndexAttr(i * sliceSize); offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
int64_t currentSliceSize = sliceSize; if (i == numSlices - 1 && lastSliceSize != 0)
if (i == numSlices - 1 && lastSliceSize != 0) {
currentSliceSize = lastSliceSize;
sizes[axis] = rewriter.getIndexAttr(lastSliceSize); sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
}
SmallVector<int64_t> sliceShape(shape.begin(), shape.end()); Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
sliceShape[axis] = currentSliceSize;
auto sliceType =
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isHostFoldableValue(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
});
slice = sliceCompute.getResult(0);
}
slices.push_back(slice); slices.push_back(slice);
} }
@@ -114,4 +107,31 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
return tensor::SplatOp::create(rewriter, loc, type, elementValue); return tensor::SplatOp::create(rewriter, loc, type, elementValue);
} }
} // namespace onnx_mlir Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
}; // namespace onnx_mlir
+279
View File
@@ -0,0 +1,279 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <type_traits>
#include <utility>
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool isWeightLikeComputeOperand(mlir::Value value) {
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
}; // namespace onnx_mlir
@@ -1,7 +0,0 @@
#pragma once
#include "ComputeRegionBuilder.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,39 +0,0 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
} // namespace onnx_mlir
@@ -1,153 +0,0 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,144 +0,0 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -1,114 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
} // namespace onnx_mlir
@@ -1,18 +0,0 @@
#pragma once
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
/// Returns true when a matrix-valued compute operand is ultimately backed by a
/// weight-marked constant/view chain and can be promoted into weights.
bool isWeightLikeComputeOperand(mlir::Value value);
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
/// compute body while reusing already-materialized intermediate values.
llvm::FailureOr<mlir::Value>
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
} // namespace onnx_mlir
@@ -1,32 +0,0 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -1,75 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (!isStaticTensorResult(op))
return false;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
return false;
}
} // namespace
bool isHostFoldableValue(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
}
bool isHostFoldableOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
}
} // namespace onnx_mlir
@@ -1,12 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
} // namespace onnx_mlir
@@ -1,29 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,4 +1,3 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -8,18 +7,25 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include "Common/Common.hpp" #include <fstream>
#include <iterator>
#include <utility>
#include "Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -27,8 +33,12 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace { namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> { struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; } StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
@@ -38,64 +48,33 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {} ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override; void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
}; };
} // namespace } // namespace
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLocs;
sourceTypes.reserve(funcOp.getNumArguments());
sourceLocs.reserve(funcOp.getNumArguments());
for (Value source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, blockArg);
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
RewritePatternSet prePatterns(ctx); RewritePatternSet mergeActivationPatterns(ctx);
populatePrePatterns(prePatterns, ctx); mergeActivationPatterns.add<onnxToArithConstant>(ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns)))) mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n"; mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp);
auto entryFunc = getPimEntryFunc(moduleOp); auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) { if (failed(entryFunc)) {
signalPassFailure(); signalPassFailure();
@@ -108,7 +87,8 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect, tensor::TensorDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>(); target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXAddOp>(); target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>(); target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>(); target.addIllegalOp<ONNXMulOp>();
@@ -127,23 +107,32 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXReduceMeanV13Op>(); target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>(); target.addIllegalOp<ONNXSplitOp>();
RewritePatternSet conversionPatterns(ctx); RewritePatternSet patterns(ctx);
populateConversionPatterns(conversionPatterns, ctx); patterns.add<removeLRN>(ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
signalPassFailure(); populateElementwisePatterns(patterns, ctx);
return; populateGemmPatterns(patterns, ctx);
} populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
RewritePatternSet earlyPostPatterns(ctx); populateReduceMeanPatterns(patterns, ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx); populateReluPatterns(patterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) { populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
// Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (Operation& op : entryFunc->getFunctionBody().front().getOperations()) for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op)) if (isa<spatial::SpatCompute>(op))
computeOpsCount++; computeOpsCount++;
@@ -160,24 +149,334 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
RewritePatternSet postPatterns(ctx); if (failed(promoteConstantInputsToWeights(*entryFunc))) {
populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) { mergeTriviallyConnectedComputes(*entryFunc);
signalPassFailure();
return;
}
populateEmptyFunction(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial0"); dumpModule(moduleOp, "spatial0");
} }
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
}
return false;
}
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
}
return false;
}
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
keep |= encapsulator<ONNXTransposeOp>(
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
keep |= encapsulator<tensor::CollapseShapeOp>(
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
}
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> trivialComputes;
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
if (compute->hasOneUse()) {
auto& use = *compute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(compute);
}
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto& computeUse = *compute->getUses().begin();
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto founded = llvm::find(newCompute.getWeights(), weight);
if (founded == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *founded);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (newCompute->hasOneUse()) {
auto& use = *newCompute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(newCompute);
}
}
for (auto compute : toErase) {
for (Value result : compute->getResults())
result.dropAllUses();
compute.erase();
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp);
});
}
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
continue;
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto& oldBlock = compute.getBody().front();
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
if (failed(clonedValue))
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
mapper.map(oldArg, *clonedValue);
}
for (auto& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
compute.replaceAllUsesWith(newCompute);
compute.erase();
}
return success();
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); } std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -5,8 +5,6 @@
namespace onnx_mlir { namespace onnx_mlir {
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -7,10 +7,11 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -146,148 +147,162 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
} }
static Value createIm2colRowComputes(Value x, static SmallVector<Value> createIm2colRowComputes(Value x,
RankedTensorType xType, RankedTensorType xType,
RankedTensorType im2colType, RankedTensorType im2colType,
RankedTensorType im2colRowType, RankedTensorType im2colRowType,
RankedTensorType gemmInputRowsType, RankedTensorType gemmInputRowType,
int64_t batchSize, int64_t batchSize,
int64_t numChannelsIn, int64_t numChannelsIn,
int64_t xHeight, int64_t xHeight,
int64_t xWidth, int64_t xWidth,
int64_t wHeight, int64_t wHeight,
int64_t wWidth, int64_t wWidth,
int64_t padHeightBegin, int64_t padHeightBegin,
int64_t padHeightEnd, int64_t padHeightEnd,
int64_t padWidthBegin, int64_t padWidthBegin,
int64_t padWidthEnd, int64_t padWidthEnd,
int64_t strideHeight, int64_t strideHeight,
int64_t strideWidth, int64_t strideWidth,
int64_t dilationHeight, int64_t dilationHeight,
int64_t dilationWidth, int64_t dilationWidth,
int64_t outWidth, int64_t outWidth,
int64_t patchSize, int64_t patchSize,
int64_t numPatches, int64_t numPatches,
int64_t numPatchesPerBatch, int64_t numPatchesPerBatch,
int64_t packFactor, int64_t packFactor,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
auto elemType = xType.getElementType(); auto elemType = xType.getElementType();
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
auto im2colComputeOp = SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) { auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
Value paddedInput = xArg; Value paddedInput = xArg;
// Pad input with zeros if needed: // Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin), rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)}; rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0), SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd), rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)}; rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
auto* padBlock = new Block(); auto* padBlock = new Block();
for (int i = 0; i < 4; i++) for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc); padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock); padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock); rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult()); tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp); rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult(); paddedInput = padOp.getResult();
} }
// Build im2col [numPatches, patchSize] incrementally to keep the IR small // Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step. // until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches); auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch); auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth); auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody()); rewriter.setInsertionPointToStart(im2colLoop.getBody());
Value patchIndex = im2colLoop.getInductionVar(); Value patchIndex = im2colLoop.getInductionVar();
Value im2colAcc = im2colLoop.getRegionIterArgs().front(); Value im2colAcc = im2colLoop.getRegionIterArgs().front();
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn), rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight), rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)}; rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight), rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)}; rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
Value row = tensor::CollapseShapeOp::create(rewriter, Value row = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
im2colRowType, im2colRowType,
patch, patch,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0}, {0},
{1, 2, 3} {1, 2, 3}
});
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col);
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
}); });
return im2colComputeOp.getResult(0); SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col);
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
SmallVector<Value> rowResults;
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
}
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
});
SmallVector<Value> rows;
rows.reserve(im2colComputeOp.getNumResults());
for (Value result : im2colComputeOp.getResults())
rows.push_back(result);
return rows;
} }
static Value createCollectedConvOutput(ValueRange gemmRows, static Value createCollectedConvOutput(ValueRange gemmRows,
@@ -305,12 +320,16 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) { auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut; Value gemmOut;
if (packFactor == 1) { if (packFactor == 1) {
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
} }
else { else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); Value packedOutput =
gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
expandedType, expandedType,
@@ -369,34 +388,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
auto wType = cast<RankedTensorType>(w.getType()); auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType()); auto outType = cast<RankedTensorType>(convOp.getY().getType());
if (!xType.hasStaticShape()) { assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); assert("Only support 2D convolution" && xType.getRank() == 4);
return failure();
} // We need to understand what is group
if (!wType.hasStaticShape()) { assert("Only support group=1" && convOp.getGroup() == 1);
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0); const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1); const int64_t numChannelsIn = xType.getDimSize(1);
@@ -413,19 +409,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const auto dilationsAttr = convOp.getDilations(); const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads(); const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -466,10 +449,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
padWidthBegin = totalPadW - padWidthEnd; padWidthBegin = totalPadW - padWidthEnd;
} }
} }
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0 // "NOTSET" or "VALID" -> all pads stay 0
} }
@@ -526,42 +505,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better. // and optionally repack several old rows into one GEMM row to use the available crossbar size better.
// //
// We want to process N pixels at the same time. Instead of doing N separate operations // The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix // the row it needs instead of receiving a full packed tensor and slicing it locally.
// containing N copies of W^T and concatenate N im2col rows into one longer row: auto gemmInputRowType =
// A_packed: [ceil(numPatches / N), N * patchSize] RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
// B_packed: [N * patchSize, N * cOut] auto gemmOutputRowType =
// Y_packed: [ceil(numPatches / N), N * cOut] RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels); SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType); xType,
auto gemmOutputRowsType = im2colType,
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); rowType,
Value gemmInputRows = createIm2colRowComputes(x, gemmInputRowType,
xType, batchSize,
im2colType, numChannelsIn,
rowType, xHeight,
gemmInputRowsType, xWidth,
batchSize, wHeight,
numChannelsIn, wWidth,
xHeight, padHeightBegin,
xWidth, padHeightEnd,
wHeight, padWidthBegin,
wWidth, padWidthEnd,
padHeightBegin, strideHeight,
padHeightEnd, strideWidth,
padWidthBegin, dilationHeight,
padWidthEnd, dilationWidth,
strideHeight, outWidth,
strideWidth, patchSize,
dilationHeight, numPatches,
dilationWidth, numPatchesPerBatch,
outWidth, effectiveMaxParallelPixels,
patchSize, rewriter,
numPatches, loc);
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmB = buildPackedWeight(wDenseAttr, Value gemmB = buildPackedWeight(wDenseAttr,
wTrans, wTrans,
@@ -577,20 +552,25 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value gemmC = buildPackedBias( Value gemmC = buildPackedBias(
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value gemmRows = ONNXGemmOp::create(rewriter, SmallVector<Value> gemmRows;
loc, gemmRows.reserve(gemmInputRows.size());
gemmOutputRowsType, for (Value gemmInputRow : gemmInputRows) {
gemmInputRows, Value gemmRow = ONNXGemmOp::create(rewriter,
gemmB, loc,
gemmC, gemmOutputRowType,
rewriter.getF32FloatAttr(1.0f), gemmInputRow,
rewriter.getF32FloatAttr(1.0f), gemmB,
rewriter.getBoolAttr(false), gemmC,
rewriter.getBoolAttr(false)) rewriter.getF32FloatAttr(1.0f),
.getY(); rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmRows.push_back(gemmRow);
}
rewriter.replaceOp(convOp, rewriter.replaceOp(convOp,
createCollectedConvOutput(ValueRange {gemmRows}, createCollectedConvOutput(gemmRows,
convOp.getType(), convOp.getType(),
gemmOutType, gemmOutType,
nhwcType, nhwcType,
@@ -5,9 +5,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -16,6 +15,13 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
strides[i] = strides[i + 1] * shape[i + 1];
return strides;
}
static DenseElementsAttr getDenseConstantAttr(Value value) { static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue()); return dyn_cast<DenseElementsAttr>(constantOp.getValue());
@@ -1,17 +1,16 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -50,45 +49,6 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
} }
static Value transposeForSpatial(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isHostFoldableValue(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
if (isHostFoldableValue(value))
return tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
value,
SmallVector<ReassociationIndices> {
{0, 1}
});
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
input,
SmallVector<ReassociationIndices> {
{0, 1}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return computeOp.getResult(0);
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> { struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -105,72 +65,6 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
ConversionPatternRewriter& rewriter) const override; ConversionPatternRewriter& rewriter) const override;
}; };
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
RankedTensorType matrixType,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numRows = matrixType.getDimSize(0);
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
if (isHostFoldableValue(matrix)) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
}
auto buildRowSlices = [&](Value matrixArg) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
};
auto cloneBatchInputChainIntoSliceCompute =
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
Value transformedMatrix = input;
if (!chainOps.empty()) {
IRMapping mapper;
mapper.map(rootValue, input);
for (Operation* chainOp : chainOps)
rewriter.clone(*chainOp, mapper);
transformedMatrix = cast<Value>(mapper.lookup(matrix));
}
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
});
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
return rowSlices;
};
SmallVector<Operation*> chainOps;
Value rootValue = matrix;
while (Operation* definingOp = rootValue.getDefiningOp()) {
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
}
if (definingOp->getNumOperands() != 1)
break;
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
break;
chainOps.push_back(definingOp);
rootValue = definingOp->getOperand(0);
}
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
}
} // namespace } // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
@@ -181,23 +75,13 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB(); Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC(); Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) { assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType()); auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType()); auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
if (!aType.hasStaticShape()) { assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0); const int64_t numOutRows = aType.getDimSize(0);
@@ -221,43 +105,47 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) { if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = expandRankOneBias(c, expandedType, rewriter, loc); c = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType; cType = expandedType;
} }
if (!cType.hasStaticShape()) { assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
cHasNumOutRows = cType.getDimSize(0) == numOutRows; cHasNumOutRows = cType.getDimSize(0) == numOutRows;
} }
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
SmallVector<Value> cSlices;
if (hasC && cHasNumOutRows)
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
SmallVector<Value> gemvOps; SmallVector<Value> gemvOps;
gemvOps.reserve(static_cast<size_t>(numOutRows)); gemvOps.reserve(numOutRows);
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) { for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c; Value cSlice = c;
if (hasC) { if (hasC) {
if (cHasNumOutRows) if (cHasNumOutRows) {
cSlice = cSlices[static_cast<size_t>(rowIdx)]; SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
else if (!isVectorShape(getTensorShape(c))) { SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows"); SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return failure(); auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
} }
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
} }
auto gemvOp = ONNXGemmOp::create(rewriter, auto gemvOp = ONNXGemmOp::create(rewriter,
loc, loc,
outRowType, outRowType,
aSlices[static_cast<size_t>(rowIdx)], aSlice,
b, b,
cSlice, cSlice,
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
@@ -268,7 +156,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) { auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs)); auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
}); });
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -300,31 +189,20 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) { if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc); c = tensor::ExpandShapeOp::create(rewriter,
gemmLoc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType; cType = expandedType;
} }
if (!cType.hasStaticShape()) { assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
} }
if (!aType.hasStaticShape()) { assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A"); && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv // Not a gemv
@@ -332,14 +210,13 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
if (transA) { if (transA) {
auto aShape = aType.getShape(); auto aShape = aType.getShape();
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType()); auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc); a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
aType = cast<RankedTensorType>(a.getType());
} }
if (transB) { if (transB) {
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc); b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType()); bType = cast<RankedTensorType>(b.getType());
} }
@@ -363,6 +240,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
auto bNumVSlices = aNumHSlices; auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices; auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize; auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices; auto outNumHSlices = cNumHSlices;
@@ -403,25 +281,19 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute( auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult { rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back( vmmOutputs.push_back(
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
if (vmmOutputs.empty()) { assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
}); });
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp->getResult(0)); partialResults.push_back(computeOp.getResult(0));
} }
if (hasC) { if (hasC) {
@@ -441,134 +313,15 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto concatComputeOp = auto concatComputeOp =
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) { createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs)); auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
}); });
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
} }
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
if (numOutRows <= 1)
return failure();
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType());
}
(void) bType;
if (!isHostFoldableValue(b))
return failure();
Value sharedBias;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
auto cType = cast<RankedTensorType>(c.getType());
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = expandRankOneBias(c, expandedType, rewriter, loc);
cType = cast<RankedTensorType>(c.getType());
}
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
return failure();
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
sharedBias = c;
}
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange(resultTypes),
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
ValueRange(weights),
ValueRange(aSlices));
Block* body = rewriter.createBlock(
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
rewriter.setInsertionPointAfter(batchOp);
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
patterns.insert<GemmToManyGemv>(ctx); patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx); patterns.insert<GemvToSpatialCompute>(ctx);
} }
@@ -4,9 +4,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,102 +14,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) { struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2)
return value;
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
return tensor::CollapseShapeOp::create(rewriter,
loc,
matrixType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
};
if (isHostFoldableValue(value))
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
}
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -120,125 +24,80 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape()) || !outType.hasStaticShape())
return failure(); return failure();
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3) if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure(); return failure();
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1; const int64_t batch = rhsType.getDimSize(0);
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1; const int64_t k = rhsType.getDimSize(1);
const int64_t batch = std::max(lhsBatch, rhsBatch); const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch)) if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
return failure(); return failure();
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
if (k != rhsK)
return failure();
if (outType.getRank() == 2) {
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
return failure();
}
else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
return failure();
}
Location loc = matmulOp.getLoc(); Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB()); auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhs = matmulOp.getA(); Value lhsTransposed =
Value rhs = matmulOp.getB(); ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
int64_t gemmK = k;
int64_t gemmN = n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
if (outType.getRank() == 2) { SmallVector<Value> gemmRows;
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); gemmRows.reserve(batch * n);
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
}
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
SmallVector<Value> batchResults;
batchResults.reserve(batch);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) { for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); for (int64_t colIdx = 0; colIdx < n; colIdx++) {
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); SmallVector<OpFoldResult> offsets = {
Value gemmResult = ONNXGemmOp::create(rewriter, rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
loc, SmallVector<OpFoldResult> sizes = {
gemmType, rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
lhsMatrix, SmallVector<OpFoldResult> strides = {
rhsMatrix, rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
none, Value rhsSlice =
rewriter.getF32FloatAttr(1.0f), tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
rewriter.getF32FloatAttr(1.0f), Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
rewriter.getBoolAttr(false), loc,
rewriter.getBoolAttr(false)) rhsRowType,
.getY(); rhsSlice,
auto batchResultCompute = SmallVector<ReassociationIndices> {
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) { {0},
Value resultMatrix = input; {1, 2}
if (useTransposedForm) {
resultMatrix = ONNXTransposeOp::create(rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
input,
rewriter.getI64ArrayAttr({1, 0}));
}
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
resultMatrix,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
}); });
batchResults.push_back(batchResultCompute.getResult(0));
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
} }
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc); auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
loc,
gemmExpandedType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, result);
return success(); return success();
} }
@@ -247,7 +106,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace } // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulToGemm>(ctx); patterns.insert<MatMulRank3ToGemm>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -5,9 +5,8 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -82,24 +81,6 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
return computeOp.getResult(0); return computeOp.getResult(0);
} }
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input, static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes, ArrayRef<bool> reducedAxes,
int64_t axis, int64_t axis,
@@ -119,7 +100,8 @@ static Value buildReduceMeanKeepdims(Value input,
for (Value slice : slices) for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc)); reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return concatValues(reducedSlices, axis, rewriter, loc); return reducedSlices.size() == 1 ? reducedSlices.front()
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
} }
static Value squeezeReducedAxes(Value keepdimsValue, static Value squeezeReducedAxes(Value keepdimsValue,
@@ -134,16 +116,9 @@ static Value squeezeReducedAxes(Value keepdimsValue,
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element}); return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
} }
auto reassociation = buildCollapseReassociation(reducedAxes); return tensor::CollapseShapeOp::create(
if (isHostFoldableValue(keepdimsValue)) rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult(); .getResult();
auto squeezeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
});
return squeezeCompute.getResult(0);
} }
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> { struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
@@ -1,20 +1,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include <optional> #include <optional>
#include <type_traits> #include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -33,6 +31,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
} }
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType()); auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
@@ -47,126 +52,27 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides); return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
} }
static Value createPoolFillElement( template <typename ReduceOp>
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
if (!useMinimumValue) assert(!windowValues.empty() && "Expected at least one pool window value.");
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
if (auto floatType = dyn_cast<FloatType>(elementType)) { Value reduced = windowValues.front();
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true); for (Value value : windowValues.drop_front())
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue)); reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
} return reduced;
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
}
llvm_unreachable("unsupported pool element type");
} }
static Value createPoolFillTensor( static Value
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) { scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue); assert(divisor > 0 && "AveragePool divisor must be positive.");
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement); if (divisor == 1)
} return reducedWindow;
template <typename PoolOp> auto tileType = cast<RankedTensorType>(reducedWindow.getType());
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter, double scale = 1.0 / static_cast<double>(divisor);
Location loc, auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
PoolOp poolOp, Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
Value input, return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
RankedTensorType inputType,
int64_t padTop,
int64_t padLeft,
int64_t padBottom,
int64_t padRight) {
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
return input;
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
inputType.getDimSize(1),
inputType.getDimSize(2) + padTop + padBottom,
inputType.getDimSize(3) + padLeft + padRight},
inputType.getElementType(),
inputType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padTop),
rewriter.getIndexAttr(padLeft)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padBottom),
rewriter.getIndexAttr(padRight)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
auto* padBlock = new Block();
for (int index = 0; index < paddedType.getRank(); ++index)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
Value padValue = createPoolFillElement(
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
tensor::YieldOp::create(rewriter, loc, padValue);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
Location loc,
Operation* op,
RankedTensorType outType,
int64_t channels,
int64_t inputHeight,
int64_t inputWidth,
int64_t outputHeight,
int64_t outputWidth,
int64_t kernelHeight,
int64_t kernelWidth,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t padTop,
int64_t padLeft,
bool countIncludePad) {
auto elemType = dyn_cast<FloatType>(outType.getElementType());
if (!elemType) {
op->emitOpError("AveragePool lowering requires a floating-point element type");
return failure();
}
auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
SmallVector<Attribute> scaleValues;
scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
for (int64_t channel = 0; channel < channels; ++channel) {
(void) channel;
for (int64_t outH = 0; outH < outputHeight; ++outH) {
for (int64_t outW = 0; outW < outputWidth; ++outW) {
int64_t validCount = 0;
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
++validCount;
}
}
const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive");
return failure();
}
scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
}
}
}
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
} }
template <typename PoolOp> template <typename PoolOp>
@@ -244,144 +150,89 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
} }
} }
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue()); const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize; const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
const bool countIncludePad = [&]() {
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
return poolOp.getCountIncludePad() == 1;
return true;
}();
Value averageScaleTensor;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
loc,
poolOp,
outType,
channels,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kernelHeight,
kernelWidth,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
padTop,
padLeft,
countIncludePad);
if (failed(maybeAverageScaleTensor))
return failure();
averageScaleTensor = *maybeAverageScaleTensor;
}
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult { createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); SmallVector<Value> batchResults;
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()); batchResults.reserve(batchSize);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); for (int64_t batch = 0; batch < batchSize; ++batch) {
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); SmallVector<Value> rows;
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount); rows.reserve(outputHeight);
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); for (int64_t outH = 0; outH < outputHeight; ++outH) {
rewriter.setInsertionPointToStart(outputLoop.getBody()); SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
Value outputPatchIndex = outputLoop.getInductionVar(); for (int64_t outW = 0; outW < outputWidth; ++outW) {
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front(); SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
Value updatedOutput = pooledOutputAcc; SmallVector<Value> windowValues;
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { windowValues.reserve(kernelHeight * kernelWidth);
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize); for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
Value reducedWindow = createPoolFillTensor( if (inH < 0 || inH >= inputHeight)
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>); continue;
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
Value paddedInH = windowBaseH; const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (kernelH * dilationHeight != 0) { if (inW < 0 || inW >= inputWidth)
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight); continue;
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
}
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
Value paddedInW = windowBaseW; rewriter.getIndexAttr(channelTile * xbarSize),
if (kernelW * dilationWidth != 0) { rewriter.getIndexAttr(inH),
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth); rewriter.getIndexAttr(inW)};
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
} }
SmallVector<OpFoldResult> offsets = {batchIndex, rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
rewriter.getIndexAttr(channelTile * xbarSize),
paddedInH,
paddedInW};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
} }
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
} }
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) { batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(channelTile * xbarSize),
outHeightIndex,
outWidthIndex};
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
}
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
rewriter.getIndexAttr(channelTile * xbarSize),
outHeightIndex,
outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
} }
scf::YieldOp::create(rewriter, loc, updatedOutput); Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.setInsertionPointAfter(outputLoop);
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,9 +1,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -33,24 +32,6 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
return computeOp.getResult(0); return computeOp.getResult(0);
} }
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
@@ -66,7 +47,8 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
for (Value slice : slices) for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return concatValues(rebuiltSlices, axis, rewriter, loc); return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
} }
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> { struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
@@ -111,13 +93,8 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value transposedInput = preTransposeCompute.getResult(0); Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax( Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
auto postTransposeCompute = result = ONNXTransposeOp::create(
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) { rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
} }
rewriter.replaceOp(softmaxOp, result); rewriter.replaceOp(softmaxOp, result);
@@ -1,10 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -20,17 +17,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs(); auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis(); int64_t axis = adaptor.getAxis();
if (llvm::all_of(inputs, isHostFoldableValue)) { rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
auto computeOp = createSpatCompute(
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
});
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
return success(); return success();
} }
@@ -5,8 +5,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
} }
if (slices.empty()) if (slices.empty())
return {}; return {};
return createSpatConcat(rewriter, loc, axis, slices); return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
} }
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
@@ -130,7 +130,9 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure(); return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
} }
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows); result = rows.size() == 1
? rows.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
} }
else { else {
return failure(); return failure();
@@ -3,10 +3,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -98,33 +95,18 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return success(); return success();
} }
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isHostFoldableValue(adaptor.getData())) {
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
Value reshaped = buildReshape(data);
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
});
rewriter.replaceOp(reshapeOp, computeOp.getResults());
return success();
};
SmallVector<ReassociationIndices> reassociation; SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank() if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) && inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
return replaceWithReshape([&](Value data) { rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); return success();
}); }
if (sourceType.getRank() < resultType.getRank() if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) && inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
return replaceWithReshape([&](Value data) { rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); return success();
}); }
return failure(); return failure();
} }
@@ -5,8 +5,8 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
} }
return createSpatConcat(rewriter, loc, axis, slices); return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
} }
struct Resize : OpConversionPattern<ONNXResizeOp> { struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -1,10 +1,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -25,10 +23,7 @@ static Value extractSliceAt(
sizes.push_back(rewriter.getIndexAttr(dim)); sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset); offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size); sizes[axis] = rewriter.getIndexAttr(size);
SmallVector<int64_t> resultShape(inputType.getShape()); return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
} }
struct Split : OpConversionPattern<ONNXSplitOp> { struct Split : OpConversionPattern<ONNXSplitOp> {
@@ -49,40 +44,21 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
outputs.reserve(splitOp.getNumResults()); outputs.reserve(splitOp.getNumResults());
int64_t offset = 0; int64_t offset = 0;
SmallVector<RankedTensorType> resultTypes;
resultTypes.reserve(splitOp.getNumResults());
SmallVector<int64_t> sliceSizes;
sliceSizes.reserve(splitOp.getNumResults());
for (Value result : splitOp.getResults()) { for (Value result : splitOp.getResults()) {
auto resultType = dyn_cast<RankedTensorType>(result.getType()); auto resultType = dyn_cast<RankedTensorType>(result.getType());
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return failure(); return failure();
resultTypes.push_back(resultType); int64_t sliceSize = resultType.getShape()[axis];
sliceSizes.push_back(resultType.getShape()[axis]); auto computeOp =
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
});
outputs.push_back(computeOp.getResult(0));
offset += sliceSize;
} }
if (isHostFoldableValue(adaptor.getInput())) { rewriter.replaceOp(splitOp, outputs);
for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
offset += sliceSize;
}
rewriter.replaceOp(splitOp, outputs);
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) {
SmallVector<Value> runtimeOutputs;
runtimeOutputs.reserve(resultTypes.size());
int64_t runtimeOffset = 0;
for (int64_t sliceSize : sliceSizes) {
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
runtimeOffset += sliceSize;
}
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
});
rewriter.replaceOp(splitOp, computeOp.getResults());
return success(); return success();
} }
}; };
@@ -1,265 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
} // namespace onnx_mlir
@@ -1,14 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
namespace onnx_mlir {
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,25 +0,0 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -1,218 +0,0 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty())
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
loc,
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
ValueRange(batchWeights),
ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
Block* newBlock =
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
newArg,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
}
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
};
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
continue;
}
}
for (Value operand : op.getOperands()) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
materializeCapturedTensor(operand);
}
Operation* cloned = rewriter.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
rewriter.setInsertionPointToEnd(newBlock);
PimHaltOp::create(rewriter, loc);
return success();
}
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -4,16 +4,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp Common.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
PhaseVerification.cpp
ReturnPathNormalization.cpp
TensorPackingPatterns.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -1,136 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
if (op->use_empty()) {
rewriter.eraseOp(op);
return success();
}
auto outputType = cast<ShapedType>(op.getResult().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received = pim::PimReceiveOp::create(rewriter,
op.getLoc(),
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
auto inputType = cast<RankedTensorType>(op.getInput().getType());
SmallVector<Value> replacements;
replacements.reserve(op.getNumResults());
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
rewriter.getIndexAttr(inputType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacements.push_back(
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
.getResult());
}
rewriter.replaceOp(op, replacements);
return success();
}
};
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value concatenated =
pim::PimConcatOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
.getOutput();
rewriter.replaceOp(op, concatenated);
return success();
}
};
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -1,42 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -1,11 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
+45 -14
View File
@@ -7,12 +7,23 @@
#include <cstddef> #include <cstddef>
#include "Common.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
assert(attr && "required precomputed channel attr is missing");
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
}
} // namespace
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) { size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/* /*
EXAMPLE RUN: EXAMPLE RUN:
@@ -63,6 +74,37 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())))); return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
} }
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
}
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
}
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
}
Operation* getEarliestUserWithinBlock(mlir::Value value) { Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers(); auto users = value.getUsers();
@@ -85,26 +127,15 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; }); return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
} }
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) { mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
for (Operation* user : value.getUsers()) {
if (user->getBlock() != operation->getBlock())
return true;
if (operation->isBeforeInBlock(user))
return true;
}
return false;
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1); assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0); mlir::Value result = operation->getResult(0);
auto resultType = result.getType(); auto resultType = result.getType();
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType)); assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation); SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) { auto validOperands =
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation); make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
});
auto bestOperand = validOperands.begin(); auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end()) if (bestOperand != validOperands.end())
+18 -1
View File
@@ -2,10 +2,16 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
/** /**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and * \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input. * its static tensor input.
@@ -24,6 +30,17 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
mlir::Value createPimReceiveFromSpatialChannel(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
template <class T> template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) { size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end()); return std::distance(range.begin(), range.end());
@@ -41,7 +58,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation); mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation); mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
inline mlir::tensor::EmptyOp inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
@@ -1,44 +0,0 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional<unsigned> {
if (inputCount == 0)
return std::nullopt;
unsigned inputBegin = op->getNumOperands() - inputCount;
if (operandNumber < inputBegin)
return std::nullopt;
return operandNumber - inputBegin;
};
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt;
}
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Operation* owner,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
rewriter.finalizeOpModification(owner);
}
} // namespace onnx_mlir
@@ -1,17 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber);
void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter,
mlir::Operation* owner,
unsigned inputIndex,
mlir::Value replacement);
} // namespace onnx_mlir
@@ -1,213 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
return static_cast<int32_t>(fallbackCoreId++);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (requireReturnUse
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
}))
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return false;
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
computeOp.getResult(0).replaceAllUsesWith(replacement);
return true;
}
} // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
return success();
SmallVector<Operation*> helperChain;
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
return success();
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
continue;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
if (result.use_empty())
continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled)
continue;
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<spatial::SpatChannelSendOp>(resultUser))
continue;
}
return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
}
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
SmallVector<Value> computeWeights;
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock);
PimHaltOp::create(rewriter, computeOp.getLoc());
return success();
}
} // namespace onnx_mlir
@@ -1,21 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,390 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
name = (baseName + "_" + Twine(suffix++)).str();
return name;
}
static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter,
Location loc,
ModuleOp moduleOp,
StringRef baseName,
MemRefType type,
Attribute initialValue = {},
UnitAttr constant = {}) {
std::string symbolName = makeUniqueSymbolName(moduleOp, baseName);
return memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(symbolName),
rewriter.getStringAttr("private"),
TypeAttr::get(type),
initialValue,
constant,
IntegerAttr {});
}
// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work.
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
return failure();
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
return failure();
}
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
return failure();
}
}
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty())
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty())
continue;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
rewriter.eraseOp(extractSliceOp);
return success();
}
};
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
}
}
}
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
if (funcOp.getArguments().empty())
return failure();
if (llvm::all_of(funcOp.getArguments(),
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
return failure();
Location loc = funcOp.getLoc();
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
if (arg.getUses().empty())
continue;
rewriter.setInsertionPoint(funcOp.getOperation());
assert(isa<mlir::RankedTensorType>(arg.getType()));
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
mlir::MemRefType memRefType =
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
std::string baseName = ("arg_" + Twine(index)).str();
auto globalOp = createPrivateMemrefGlobalWithUniqueName(
rewriter, loc, funcOp->getParentOfType<ModuleOp>(), baseName, memRefType);
std::string argName = globalOp.getSymName().str();
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor);
}
else {
rewriter.setInsertionPoint(argUser);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(argUser);
argUses.set(toTensor);
rewriter.finalizeOpModification(argUser);
}
}
}
return success();
}
};
} // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
}
@@ -1,20 +0,0 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
bool hasFailure = false;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
hasFailure = true;
});
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
@@ -1,587 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
struct ReturnUseInfo {
size_t returnIndex;
SmallVector<Operation*> helperChain;
};
struct ConcatReturnUseInfo {
size_t returnIndex;
SmallVector<int64_t> sliceOffsets;
SmallVector<int64_t> concatShape;
SmallVector<Operation*> concatChain;
SmallVector<Operation*> helperChain;
};
static bool isReturnHelperChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void markOpToRemove(ReturnPathState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
name = (baseName + "_" + Twine(suffix++)).str();
return name;
}
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
int64_t flatIndex = 0;
for (size_t i = 0; i < shape.size(); ++i) {
flatIndex *= shape[i];
flatIndex += indices[i];
}
return flatIndex;
}
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
SmallVector<int64_t> indices(shape.size(), 0);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
indices[dim] = flatIndex % shape[dim];
flatIndex /= shape[dim];
}
return indices;
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin()))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isReturnHelperChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
auto uses = value.getUses();
if (rangeLength(uses) != 1)
return std::nullopt;
SmallVector<Operation*> helperChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isReturnHelperChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentValue = currentUser->getResult(0);
currentUser = currentUses.begin()->getOwner();
}
if (!isa<func::ReturnOp>(currentUser))
return std::nullopt;
return ReturnUseInfo {
currentValue.getUses().begin()->getOperandNumber(),
std::move(helperChain),
};
}
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation* op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getResult();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getOutput();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getOutput();
return {};
};
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getDim();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getAxis();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getAxis();
return std::nullopt;
};
auto getConcatOperands = [](Operation* op) -> OperandRange {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getOperands();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getInputs();
return cast<pim::PimConcatOp>(op).getInputs();
};
auto uses = value.getUses();
if (rangeLength(uses) != 1
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
return std::nullopt;
auto valueType = dyn_cast<ShapedType>(value.getType());
if (!valueType || !valueType.hasStaticShape())
return std::nullopt;
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
SmallVector<Operation*> concatChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
concatChain.push_back(currentUser);
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
int64_t axis = *getConcatAxis(currentUser);
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
Value concatResult = getConcatResult(currentUser);
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
if (!concatType || !concatType.hasStaticShape())
return std::nullopt;
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
currentValue = concatResult;
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt;
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
return std::nullopt;
currentValue = helperCompute.getResult(0);
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
while (isReturnHelperChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentValue = currentUser->getResult(0);
currentUser = currentUses.begin()->getOwner();
}
if (!isa<func::ReturnOp>(currentUser))
return std::nullopt;
return ConcatReturnUseInfo {
currentValue.getUses().begin()->getOperandNumber(),
std::move(sliceOffsets),
std::move(concatShape),
std::move(concatChain),
std::move(helperChain),
};
}
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
ArrayRef<int64_t> sourceShape,
ArrayRef<Operation*> helperChain,
SmallVectorImpl<int64_t>& mappedIndices) {
SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
currentIndices = expandFlatElementIndex(flatIndex, currentShape);
return success();
};
for (Operation* op : helperChain) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
return failure();
SmallVector<int64_t> nextIndices;
nextIndices.reserve(currentIndices.size());
for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
extractSliceOp.getStaticOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
if (stride != 1 || index < offset || index >= offset + size)
return failure();
nextIndices.push_back(index - offset);
}
auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
currentIndices = std::move(nextIndices);
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
currentIndices = std::move(nextIndices);
currentShape = std::move(nextShape);
continue;
}
if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
currentIndices = std::move(nextIndices);
currentShape = std::move(nextShape);
continue;
}
if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
if (failed(reshapeToResultShape(op)))
return failure();
continue;
}
return failure();
}
mappedIndices.assign(currentIndices.begin(), currentIndices.end());
return success();
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static void
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
IRMapping mapping;
mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
clonedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static Value emitHostCopy(IRRewriter& rewriter,
Location loc,
Value outputTensor,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
return PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
} // namespace
void addReturnOutputBuffers(func::ReturnOp returnOp,
IRRewriter& rewriter,
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Value currentReturnValue = returnValue;
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
}
else {
auto outRankedTensorType = llvm::dyn_cast<RankedTensorType>(currentReturnValue.getType());
auto memRefType = MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputBaseName = ("output_" + Twine(index)).str();
std::string outputName = makeUniqueSymbolName(returnOp->getParentOfType<ModuleOp>(), outputBaseName);
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
outputTensors.push_back([memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
}
}
}
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType());
if (auto returnUse = analyzeReturnUse(result)) {
Value storedValue = yieldValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op);
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
yieldValue,
static_cast<int32_t>(flatOffset * elementSize),
0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
SmallVector<OpFoldResult> extractOffsets;
SmallVector<OpFoldResult> extractSizes;
SmallVector<OpFoldResult> extractStrides;
extractOffsets.reserve(storedType.getRank());
extractSizes.reserve(storedType.getRank());
extractStrides.reserve(storedType.getRank());
for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
extractOffsets.push_back(rewriter.getIndexAttr(idx));
extractSizes.push_back(rewriter.getIndexAttr(1));
extractStrides.push_back(rewriter.getIndexAttr(1));
}
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
outputTensor = emitHostCopy(rewriter,
loc,
outputTensor,
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
}
return ReturnPathLoweringResult::Handled;
}
return ReturnPathLoweringResult::NotReturnPath;
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
return;
bool isExclusivelyOwnedByReturnChain = op->use_empty();
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|| isReturnHelperChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
if (isReturnHelperChainOp(op)) {
Value source = op->getOperand(0);
markOpToRemove(state, op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(state, computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
}
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}
}
} // namespace onnx_mlir
@@ -1,37 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
struct ReturnPathState {
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
mlir::IRRewriter& rewriter,
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def HasSpatialChannelSourceCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
"spatial channel has precomputed source core id">;
def HasSpatialChannelTargetCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
"spatial channel has precomputed target core id">;
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
def onnxToPimTranspose : Pat< def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms), (ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms, (PimTransposeOp $data, $perms,
@@ -16,11 +27,17 @@ def onnxToPimTranspose : Pat<
>; >;
def spatToPimVMM : Pat< def spatToPimVMM : Pat<
(SpatVMMOp:$srcOpRes $weightIndex, $vector), (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector, (PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimMVM : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVAdd : Pat< def spatToPimVVAdd : Pat<
(SpatVAddOp:$srcOpRes $a, $b), (SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b, (PimVVAddOp $a, $b,
@@ -63,4 +80,18 @@ def spatToPimVSoftmax : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatChannelSendToPimSend : Pat<
(SpatChannelSendOp $channel, $input),
(PimSendOp $input,
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
[(HasSpatialChannelTargetCoreIdAttr $channel)]
>;
def spatChannelReceiveToPimReceive : Pat<
(SpatChannelReceiveOp:$srcOpRes $channel),
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
[(HasSpatialChannelSourceCoreIdAttr $channel)]
>;
#endif // SPATIAL_TO_PIM #endif // SPATIAL_TO_PIM

Some files were not shown because too many files have changed in this diff Show More