Compare commits
40 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 628dc630a4 | |||
| 80a7298552 | |||
| 8ad504fcdf | |||
| e6f442c5d2 | |||
| f6b97b3813 | |||
| 26317ea7d0 | |||
| 909c4acfdd | |||
| feaff820e1 | |||
| 1e279ae9bb | |||
| 57f0cca8c0 | |||
| 5ff364027b | |||
| b1272d2283 | |||
| 58e6587697 | |||
| f6c8cc4aa5 | |||
| 566630b99a | |||
| 74931ad75b | |||
| f2fe147961 | |||
| 7bb58e80de | |||
| b2dc9c38b6 | |||
| 3cb6a1abc5 | |||
| 285773fa55 | |||
| bdacb9871d | |||
| 5b9bb0c191 | |||
| f789954ad7 | |||
| b6ba1e4fea | |||
| 717ad160cd | |||
| 905fa9f9a7 | |||
| 62b0a6e19d | |||
| b605585b1f | |||
| 08b0fcd850 | |||
| 9dccc2c701 | |||
| 5c839e62c1 | |||
| 15e8edb9c4 | |||
| 951baca106 | |||
| fc5bccb487 | |||
| 49dea15b95 | |||
| 5545b0f672 | |||
| cff929a083 | |||
| 89b3501aa8 | |||
| 412ca957f6 |
+12
@@ -1,5 +1,17 @@
|
|||||||
|
.zed
|
||||||
.idea
|
.idea
|
||||||
**/.vscode
|
**/.vscode
|
||||||
|
|
||||||
.claude
|
.claude
|
||||||
|
.codex
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
|
|
||||||
|
CMakeUserPresets.json
|
||||||
|
|
||||||
build
|
build
|
||||||
|
build_release
|
||||||
|
cmake-build-debug
|
||||||
|
cmake-build-release
|
||||||
|
compile.sh
|
||||||
|
|
||||||
|
**/__*
|
||||||
|
|||||||
@@ -1,5 +1,159 @@
|
|||||||
# Raptor
|
# Raptor
|
||||||
|
|
||||||
|
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
|
||||||
|
targeting in-memory computing / processing-in-memory (PIM) architectures.
|
||||||
|
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
|
||||||
|
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
PIM architectures perform most of the computation directly in memory.
|
||||||
|
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
|
||||||
|
- a shared host memory,
|
||||||
|
- a number of cores that do most of the computation directly in their memory
|
||||||
|
(vector ops, vmm/mvm on ReRAM crossbars),
|
||||||
|
- no branching instructions (branchless architecture) and no hardware loop
|
||||||
|
support — any repeated work (e.g. convolutions) must be unrolled into
|
||||||
|
explicit per-iteration instructions.
|
||||||
|
|
||||||
|
Because of this, the amount of emitted instructions explodes quickly and the
|
||||||
|
compiler must optimize aggressively at every stage to keep compilation
|
||||||
|
tractable.
|
||||||
|
|
||||||
|
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
|
||||||
|
each carrying its own in-memory computing unit and crossbars. It will live in
|
||||||
|
a dedicated dialect (future work).
|
||||||
|
|
||||||
|
### Targets and simulators
|
||||||
|
|
||||||
|
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
|
||||||
|
**performance** estimates (latency, energy), but does not functionally execute
|
||||||
|
the JSON code it consumes. To validate the numerical correctness of the JSON
|
||||||
|
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
|
||||||
|
a Rust simulator we maintain in-tree at
|
||||||
|
`backend-simulators/pim/pim-simulator`.
|
||||||
|
|
||||||
|
## Compilation pipeline
|
||||||
|
|
||||||
|
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
|
||||||
|
When working on this codebase, most changes should stay confined to those
|
||||||
|
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
|
||||||
|
framework-level details).
|
||||||
|
|
||||||
|
High-level lowering flow:
|
||||||
|
|
||||||
|
```
|
||||||
|
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
|
||||||
|
```
|
||||||
|
|
||||||
|
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
|
||||||
|
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
|
||||||
|
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
|
||||||
|
operations are accelerated by storing a constant RHS matrix into a
|
||||||
|
crossbar. Crossbars cannot be re-programmed during execution, have a
|
||||||
|
limited fixed size, and there is a limited number of them per core.
|
||||||
|
Conversion patterns are split by op family under
|
||||||
|
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
|
||||||
|
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
|
||||||
|
Reshape, Resize, Split).
|
||||||
|
|
||||||
|
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
|
||||||
|
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
|
||||||
|
materializes PIM cores (`pim.core`), inter-core communication
|
||||||
|
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
|
||||||
|
|
||||||
|
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
|
||||||
|
A DCP-inspired heuristic (Dynamic Critical Path — see the original
|
||||||
|
scheduling paper by Kwok & Ahmad,
|
||||||
|
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
|
||||||
|
that coarsens the virtual node graph and decides how to group compute
|
||||||
|
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
|
||||||
|
heuristic with different assumptions from the paper (different cost
|
||||||
|
model, constraints from crossbar capacity / core resources, and a
|
||||||
|
windowed coarsening loop instead of full-graph reprioritization). The
|
||||||
|
`dcp-critical-window-size` option controls how many lowest-slack virtual
|
||||||
|
nodes each coarsening iteration considers (0 = legacy full-graph
|
||||||
|
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
|
||||||
|
`MergeComputeNodesPass.cpp`.
|
||||||
|
|
||||||
|
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
|
||||||
|
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
|
||||||
|
standard MLIR `BufferizableOpInterface` machinery
|
||||||
|
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
|
||||||
|
|
||||||
|
5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
|
||||||
|
- `HostConstantFolding` — folds host-side constants.
|
||||||
|
- `MaterializeHostConstantsPass` — materializes the remaining host
|
||||||
|
constants for emission.
|
||||||
|
- `VerificationPass` — checks invariants before emission.
|
||||||
|
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
|
||||||
|
and `pim-simulator`.
|
||||||
|
|
||||||
|
Supporting pieces:
|
||||||
|
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
|
||||||
|
core count, DCP window, experimental conv impl, concat error handling, …)
|
||||||
|
and `PimCodeGen` entry points.
|
||||||
|
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
|
||||||
|
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
|
||||||
|
and the `PIMPasses.h` registry used by `PimAccelerator`.
|
||||||
|
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
|
||||||
|
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
|
||||||
|
|
||||||
|
## Key compiler options
|
||||||
|
|
||||||
|
Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||||
|
|
||||||
|
- `--maccel=PIM` — select the PIM accelerator.
|
||||||
|
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
|
||||||
|
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
|
||||||
|
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
|
||||||
|
run only the codegen tail.
|
||||||
|
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||||
|
per-core count.
|
||||||
|
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
||||||
|
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||||
|
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||||
|
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||||
|
|
||||||
|
## Validation
|
||||||
|
|
||||||
|
Functional validation lives in `validation/` and drives the Rust
|
||||||
|
`pim-simulator` to compare Raptor's output against a reference.
|
||||||
|
|
||||||
|
Per-operation validation (from `validation/`):
|
||||||
|
|
||||||
|
```
|
||||||
|
validate.py \
|
||||||
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
|
--onnx-include-dir ../onnx-mlir/include
|
||||||
|
```
|
||||||
|
|
||||||
|
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||||
|
|
||||||
|
```
|
||||||
|
validate.py \
|
||||||
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
|
--onnx-include-dir ../onnx-mlir/include \
|
||||||
|
--operations-dir ./networks/yolo11n/depth_04 \
|
||||||
|
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||||
|
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
||||||
|
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
||||||
|
`sigmoid`, `softmax`, `split`.
|
||||||
|
|
||||||
|
## Rebuilding
|
||||||
|
|
||||||
|
Release build (fast):
|
||||||
|
|
||||||
|
```
|
||||||
|
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
|
||||||
|
```
|
||||||
|
|
||||||
|
A slower debug build is also available — configure it the same way but with
|
||||||
|
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Protobuf
|
### Protobuf
|
||||||
|
|||||||
+2102
-8
File diff suppressed because it is too large
Load Diff
@@ -13,8 +13,9 @@ name = "pimcore"
|
|||||||
path = "src/lib/pimcore.rs"
|
path = "src/lib/pimcore.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["tracing"]
|
default = []
|
||||||
tracing = []
|
tracing = []
|
||||||
|
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -27,3 +28,9 @@ hex = "0"
|
|||||||
paste = "1"
|
paste = "1"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
statrs = {version="0.16", optional=true}
|
||||||
|
comfy-table = {version="7.1", optional=true}
|
||||||
|
plotly = {version="0.8", optional=true}
|
||||||
|
rayon = "1.12.0"
|
||||||
|
faer = "0.24.0"
|
||||||
|
faer-traits = "0.24.0"
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
cpu::{CPU, crossbar}, instruction_set::{
|
cpu::{CPU, crossbar},
|
||||||
|
instruction_set::{
|
||||||
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
|
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
|
||||||
helper::add_all,
|
helper::add_all,
|
||||||
}, memory_manager::{
|
},
|
||||||
|
memory_manager::{
|
||||||
MemoryStorable,
|
MemoryStorable,
|
||||||
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
||||||
}, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
|
},
|
||||||
|
tracing::TRACER,
|
||||||
|
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
|
||||||
};
|
};
|
||||||
use aligned_vec::{AVec, ConstAlign};
|
use aligned_vec::{AVec, ConstAlign};
|
||||||
use anyhow::{Context, Result, ensure};
|
use anyhow::{Context, Result, ensure};
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
use paste::paste;
|
use paste::paste;
|
||||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
|
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
|
||||||
@@ -76,8 +81,7 @@ pub fn functor_to_name(functor: usize) -> &'static str {
|
|||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
/////////////////Scalar/register Instructions//////////////////
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
{
|
|
||||||
TRACER.lock().unwrap().pre_sldi(cores, data);
|
TRACER.lock().unwrap().pre_sldi(cores, data);
|
||||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||||
let core = cores.core(core_indx);
|
let core = cores.core(core_indx);
|
||||||
@@ -229,25 +233,30 @@ where
|
|||||||
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||||
[M]: UpcastSlice<T>,
|
[M]: UpcastSlice<T>,
|
||||||
T: UpcastDestTraits<T> + MemoryStorable,
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
// Add faer::ComplexField HERE, directly bounding M for this function only
|
||||||
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
|
||||||
F: UpcastDestTraits<F> + MemoryStorable,
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
{
|
{
|
||||||
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
|
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
|
||||||
|
|
||||||
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
||||||
let group: usize = group.try_into().context("group can not be negative")?;
|
let group: usize = group.try_into().context("group can not be negative")?;
|
||||||
|
|
||||||
let core = cores.core(core_indx);
|
let core = cores.core(core_indx);
|
||||||
let r1_val = core.register(r1);
|
let r1_val = core.register(r1);
|
||||||
let rd_val = core.register(rd);
|
let rd_val = core.register(rd);
|
||||||
|
|
||||||
let (memory, crossbars) = core.get_memory_crossbar();
|
let (memory, crossbars) = core.get_memory_crossbar();
|
||||||
let crossbar = crossbars.get_mut(group).unwrap();
|
let crossbar = crossbars.get_mut(group).unwrap();
|
||||||
let crossbar_stored_bytes = crossbar.stored_bytes();
|
let crossbar_stored_bytes = crossbar.stored_bytes();
|
||||||
let crossbar_byte_width = crossbar.width();
|
let crossbar_byte_width = crossbar.width();
|
||||||
//Fix this
|
|
||||||
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
||||||
ensure!(
|
ensure!(
|
||||||
crossbar_byte_width & size_of::<M>() == 0,
|
crossbar_byte_width % size_of::<M>() == 0,
|
||||||
"M not divisor of the crosbbar size"
|
"M not divisor of the crosbbar size"
|
||||||
);
|
);
|
||||||
|
|
||||||
let crossbar_height = crossbar.height();
|
let crossbar_height = crossbar.height();
|
||||||
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
||||||
|
|
||||||
@@ -257,19 +266,29 @@ where
|
|||||||
let load = loads[0];
|
let load = loads[0];
|
||||||
let vec: Cow<[M]> = load.up();
|
let vec: Cow<[M]> = load.up();
|
||||||
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
||||||
let mut res = Vec::with_capacity(crossbar_elem_width);
|
|
||||||
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
|
|
||||||
partial.resize(vec.len(), M::from_f32(0.0));
|
|
||||||
|
|
||||||
for x in 0..crossbar_elem_width {
|
// --- FAER IMPLEMENTATION ---
|
||||||
partial[0] = vec[0] * matrix[x];
|
|
||||||
for y in 1..crossbar_height {
|
// 1. Explicitly create a Matrix Reference (MatRef)
|
||||||
partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
|
let matrix_view = faer::mat::MatRef::from_row_major_slice(
|
||||||
}
|
matrix.as_ref(),
|
||||||
|
crossbar_height,
|
||||||
|
crossbar_elem_width,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Explicitly create a Column Vector Reference (ColRef)
|
||||||
|
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
|
||||||
|
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
|
||||||
|
|
||||||
|
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
|
||||||
|
|
||||||
|
// 4. Convert back to standard Rust Vec
|
||||||
|
// try_as_slice() returns an Option<&[M]>.
|
||||||
|
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
|
||||||
|
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
|
||||||
|
|
||||||
|
// --- END FAER ---
|
||||||
|
|
||||||
let mut acc = add_all(partial.as_slice());
|
|
||||||
res.push(acc);
|
|
||||||
}
|
|
||||||
if relu != 0 {
|
if relu != 0 {
|
||||||
res.iter_mut().for_each(|x| {
|
res.iter_mut().for_each(|x| {
|
||||||
if *x < M::from_f32(0.0) {
|
if *x < M::from_f32(0.0) {
|
||||||
@@ -277,12 +296,15 @@ where
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
ensure!(
|
ensure!(
|
||||||
res.len() == crossbar_elem_width,
|
res.len() == crossbar_elem_width,
|
||||||
"mvm generate a vector bigger thant it's requested elements"
|
"mvm generate a vector bigger thant it's requested elements"
|
||||||
);
|
);
|
||||||
|
|
||||||
let res_up: Cow<[T]> = res.as_slice().up();
|
let res_up: Cow<[T]> = res.as_slice().up();
|
||||||
core.execute_store(rd_val, res_up.as_ref());
|
core.execute_store(rd_val, res_up.as_ref());
|
||||||
|
|
||||||
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
|
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
|
||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
@@ -533,7 +555,10 @@ where
|
|||||||
let r2_val = r2;
|
let r2_val = r2;
|
||||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||||
let rd_val = core.register(rd);
|
let rd_val = core.register(rd);
|
||||||
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
ensure!(
|
||||||
|
offset_select == 1,
|
||||||
|
"Offset select cannot be different from 1"
|
||||||
|
);
|
||||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||||
let load1 = loads[0];
|
let load1 = loads[0];
|
||||||
@@ -633,7 +658,10 @@ pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionSta
|
|||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vsoftmax_impl<F, T>(
|
||||||
|
cores: &mut CPU,
|
||||||
|
data: InstructionData,
|
||||||
|
) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
T: UpcastDestTraits<T> + MemoryStorable,
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
@@ -656,12 +684,11 @@ where
|
|||||||
.reduce(|a, b| if a > b { a } else { b })
|
.reduce(|a, b| if a > b { a } else { b })
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
|
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
|
||||||
let sum = exp_values
|
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap();
|
||||||
.iter()
|
ensure!(
|
||||||
.copied()
|
sum > 0.0.into(),
|
||||||
.reduce(|a, b| a + b)
|
"vsoftmax normalization sum must be positive"
|
||||||
.unwrap();
|
);
|
||||||
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
|
|
||||||
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
|
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
|
||||||
let res_up: Cow<[T]> = res.as_slice().up();
|
let res_up: Cow<[T]> = res.as_slice().up();
|
||||||
core.execute_store(rd_val, res_up.as_ref());
|
core.execute_store(rd_val, res_up.as_ref());
|
||||||
@@ -749,12 +776,10 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_send(cores, data);
|
|
||||||
Ok(InstructionStatus::Sending(data))
|
Ok(InstructionStatus::Sending(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_recv(cores, data);
|
|
||||||
Ok(InstructionStatus::Reciving(data))
|
Ok(InstructionStatus::Reciving(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -55,17 +55,25 @@ pub trait HasSigm {
|
|||||||
|
|
||||||
impl HasSigm for f32 {
|
impl HasSigm for f32 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
|
if self >= 0.0 {
|
||||||
|
1.0 / (1.0 + (-self).exp())
|
||||||
|
} else {
|
||||||
let ex = self.exp();
|
let ex = self.exp();
|
||||||
ex / (1.0 + ex)
|
ex / (1.0 + ex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl HasSigm for f64 {
|
impl HasSigm for f64 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
|
if self >= 0.0 {
|
||||||
|
1.0 / (1.0 + (-self).exp())
|
||||||
|
} else {
|
||||||
let ex = self.exp();
|
let ex = self.exp();
|
||||||
ex / (1.0 + ex)
|
ex / (1.0 + ex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait HasExp {
|
pub trait HasExp {
|
||||||
fn exp(self) -> Self;
|
fn exp(self) -> Self;
|
||||||
|
|||||||
@@ -169,6 +169,9 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
print_status(cores_instructions);
|
print_status(cores_instructions);
|
||||||
|
|
||||||
|
#[cfg(feature = "profile_time")]
|
||||||
|
TRACER.lock().unwrap().report();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cpu(&self) -> &CPU<'a> {
|
pub fn cpu(&self) -> &CPU<'a> {
|
||||||
|
|||||||
@@ -58,6 +58,20 @@ where 'a : 'b
|
|||||||
&& sender.internal_core == receiver.external_core
|
&& sender.internal_core == receiver.external_core
|
||||||
&& receiver.internal_core == sender.external_core
|
&& receiver.internal_core == sender.external_core
|
||||||
{
|
{
|
||||||
|
{
|
||||||
|
let sender = &mut core_instructions[sender.internal_core];
|
||||||
|
let pc = sender.program_counter;
|
||||||
|
let inst = sender.instructions.get(pc).unwrap();
|
||||||
|
let data = inst.data;
|
||||||
|
TRACER.lock().unwrap().pre_send(cpu, data);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let recv = &mut core_instructions[receiver.internal_core];
|
||||||
|
let pc = recv.program_counter;
|
||||||
|
let inst = recv.instructions.get(pc).unwrap();
|
||||||
|
let data = inst.data;
|
||||||
|
TRACER.lock().unwrap().pre_recv(cpu, data);
|
||||||
|
}
|
||||||
let [sender_core, reciver_core] =
|
let [sender_core, reciver_core] =
|
||||||
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
|
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
|
||||||
let memory = sender_core
|
let memory = sender_core
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
#[cfg(not(feature = "tracing"))]
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
impl Trace {
|
impl Trace {
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
/////////////////Scalar/register Instructions//////////////////
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
|
|||||||
@@ -1,52 +1,32 @@
|
|||||||
mod tracing_isa;
|
|
||||||
mod disable;
|
mod disable;
|
||||||
mod pretty_print;
|
#[cfg(feature = "profile_time")]
|
||||||
use std::{fs::File, path::{ PathBuf}};
|
mod profile;
|
||||||
|
|
||||||
|
#[cfg(feature = "profile_time")]
|
||||||
|
use profile::Trace;
|
||||||
|
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
mod trace;
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
use trace::Trace;
|
||||||
|
|
||||||
|
use crate::Executable;
|
||||||
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::sync::{LazyLock, Mutex};
|
use std::sync::{LazyLock, Mutex};
|
||||||
|
|
||||||
|
|
||||||
use crate::Executable;
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
|
pub struct Trace {}
|
||||||
|
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
pub struct Trace {
|
|
||||||
out_files : Vec<File>
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[cfg(feature = "tracing")]
|
|
||||||
impl Trace {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self { out_files : Vec::new()}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pub fn init(&mut self, num_core : usize , mut path : PathBuf) {
|
|
||||||
path.pop();
|
|
||||||
for i in 0..num_core {
|
|
||||||
path.push(format!("TraceCore{}", i));
|
|
||||||
let file = File::create(&path).expect("Can not create file");
|
|
||||||
self.out_files.push(file);
|
|
||||||
path.pop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(feature = "tracing"))]
|
|
||||||
pub struct Trace {
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[cfg(not(feature = "tracing"))]
|
|
||||||
impl Trace {
|
impl Trace {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {}
|
Self {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
|
||||||
pub fn init(&mut self, num_core : usize, path : PathBuf ) {
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| { Trace::new().into()});
|
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into());
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
use std::{collections::HashMap, path::PathBuf, time::Instant};
|
||||||
|
|
||||||
|
use crate::tracing::profile::profile_analysis::{
|
||||||
|
analyze_timings, generate_interactive_report, print_textual_report,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod profile_analysis;
|
||||||
|
pub mod profile_isa;
|
||||||
|
|
||||||
|
pub struct Trace {
|
||||||
|
instruction_times: HashMap<String, Vec<(u128,u128)>>,
|
||||||
|
core_start_time: HashMap<usize, Option<Instant>>,
|
||||||
|
start_time: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Trace {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let mut instruction_times = HashMap::new();
|
||||||
|
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
|
||||||
|
Self {
|
||||||
|
instruction_times,
|
||||||
|
core_start_time: HashMap::new(),
|
||||||
|
start_time: Instant::now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, num_core: usize, path: PathBuf) {
|
||||||
|
for i in 0..num_core {
|
||||||
|
self.core_start_time.insert(i, None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn report(&self) {
|
||||||
|
let res = analyze_timings(&self.instruction_times);
|
||||||
|
print_textual_report(&res);
|
||||||
|
generate_interactive_report(
|
||||||
|
&self.instruction_times,
|
||||||
|
&["mvmul", "recv"],
|
||||||
|
"/tmp/report.html",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,192 @@
|
|||||||
|
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
|
||||||
|
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct InstructionStats {
|
||||||
|
pub name: String,
|
||||||
|
pub count: usize,
|
||||||
|
pub total_time: u128,
|
||||||
|
pub min: f64,
|
||||||
|
pub max: f64,
|
||||||
|
pub mean: f64,
|
||||||
|
pub median: f64,
|
||||||
|
pub std_dev: f64,
|
||||||
|
pub cv: f64,
|
||||||
|
pub p95: f64,
|
||||||
|
pub p99: f64,
|
||||||
|
pub skewness: f64,
|
||||||
|
pub kurtosis: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_time(ns: f64) -> String {
|
||||||
|
if ns.is_nan() {
|
||||||
|
return "NaN".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
if ns >= 1_000_000_000.0 {
|
||||||
|
format!("{:.2} s", ns / 1_000_000_000.0)
|
||||||
|
} else if ns >= 1_000_000.0 {
|
||||||
|
format!("{:.2} ms", ns / 1_000_000.0)
|
||||||
|
} else if ns >= 1_000.0 {
|
||||||
|
format!("{:.2} µs", ns / 1_000.0)
|
||||||
|
} else {
|
||||||
|
format!("{:.2} ns", ns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
|
||||||
|
let n = times.len() as f64;
|
||||||
|
|
||||||
|
if n < 4.0 || std_dev == 0.0 {
|
||||||
|
return (f64::NAN, f64::NAN);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sum_m3 = 0.0;
|
||||||
|
let mut sum_m4 = 0.0;
|
||||||
|
|
||||||
|
for &x in times {
|
||||||
|
let deviation = x - mean;
|
||||||
|
sum_m3 += deviation.powi(3);
|
||||||
|
sum_m4 += deviation.powi(4);
|
||||||
|
}
|
||||||
|
|
||||||
|
let m3 = sum_m3 / n;
|
||||||
|
let m4 = sum_m4 / n;
|
||||||
|
|
||||||
|
let skewness = m3 / std_dev.powi(3);
|
||||||
|
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
|
||||||
|
|
||||||
|
(skewness, kurtosis)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
for (instruction, times) in timings {
|
||||||
|
let count = times.len();
|
||||||
|
if count == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract ONLY the duration (the second element of the tuple) for stats
|
||||||
|
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
|
||||||
|
let total_time: u128 = durations.iter().sum();
|
||||||
|
|
||||||
|
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
|
||||||
|
let mut data = Data::new(f64_times.clone());
|
||||||
|
|
||||||
|
let mean = data.mean().unwrap_or(0.0);
|
||||||
|
let std_dev = data.std_dev().unwrap_or(0.0);
|
||||||
|
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
|
||||||
|
|
||||||
|
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
|
||||||
|
|
||||||
|
results.push(InstructionStats {
|
||||||
|
name: instruction.clone(),
|
||||||
|
count,
|
||||||
|
total_time,
|
||||||
|
min: data.min(),
|
||||||
|
max: data.max(),
|
||||||
|
mean,
|
||||||
|
median: data.median(),
|
||||||
|
std_dev,
|
||||||
|
cv,
|
||||||
|
p95: data.percentile(95),
|
||||||
|
p99: data.percentile(99),
|
||||||
|
skewness,
|
||||||
|
kurtosis,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn print_textual_report(stats: &[InstructionStats]) {
|
||||||
|
let mut table = Table::new();
|
||||||
|
table
|
||||||
|
.load_preset(UTF8_FULL)
|
||||||
|
.apply_modifier(UTF8_ROUND_CORNERS)
|
||||||
|
.set_header(vec![
|
||||||
|
"Instruction",
|
||||||
|
"Count",
|
||||||
|
"Total Time",
|
||||||
|
"Mean",
|
||||||
|
"Median",
|
||||||
|
"Min",
|
||||||
|
"Max",
|
||||||
|
"P95",
|
||||||
|
"P99",
|
||||||
|
"StdDev",
|
||||||
|
"CV",
|
||||||
|
"Skewness",
|
||||||
|
"Kurtosis",
|
||||||
|
]);
|
||||||
|
|
||||||
|
for stat in stats {
|
||||||
|
table.add_row(vec![
|
||||||
|
Cell::new(&stat.name),
|
||||||
|
Cell::new(stat.count.to_string()),
|
||||||
|
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
|
||||||
|
Cell::new(format_time(stat.mean)),
|
||||||
|
Cell::new(format_time(stat.median)),
|
||||||
|
Cell::new(format_time(stat.min)),
|
||||||
|
Cell::new(format_time(stat.max)),
|
||||||
|
Cell::new(format_time(stat.p95)),
|
||||||
|
Cell::new(format_time(stat.p99)),
|
||||||
|
Cell::new(format_time(stat.std_dev)),
|
||||||
|
Cell::new(format!("{:.3}", stat.cv)),
|
||||||
|
Cell::new(format!("{:.2}", stat.skewness)),
|
||||||
|
Cell::new(format!("{:.2}", stat.kurtosis)),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{table}");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub fn generate_interactive_report(
|
||||||
|
timings: &HashMap<String, Vec<(u128, u128)>>,
|
||||||
|
instructions_to_plot: &[&str], // <-- NEW: Only plot these
|
||||||
|
file_path: &str,
|
||||||
|
) {
|
||||||
|
|
||||||
|
use plotly::common::{Mode, Marker, Line};
|
||||||
|
use plotly::layout::{Axis, Layout};
|
||||||
|
use plotly::{Plot, Scatter};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
let mut plot = Plot::new();
|
||||||
|
|
||||||
|
for &instruction_name in instructions_to_plot {
|
||||||
|
// Only proceed if the instruction exists in our timings map
|
||||||
|
if let Some(times) = timings.get(instruction_name) {
|
||||||
|
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
|
||||||
|
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
|
||||||
|
|
||||||
|
let text_array: Vec<String> = times.iter()
|
||||||
|
.map(|&(_, dur)| format_time(dur as f64))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let trace = Scatter::new(x_axis, y_axis)
|
||||||
|
.name(instruction_name)
|
||||||
|
.mode(Mode::LinesMarkers)
|
||||||
|
.marker(Marker::new().size(4).opacity(0.6))
|
||||||
|
.line(Line::new().width(1.0))
|
||||||
|
.text_array(text_array)
|
||||||
|
.hover_info(plotly::common::HoverInfo::All);
|
||||||
|
|
||||||
|
plot.add_trace(trace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let layout = Layout::new()
|
||||||
|
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
|
||||||
|
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
|
||||||
|
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
|
||||||
|
|
||||||
|
plot.set_layout(layout);
|
||||||
|
plot.write_html(file_path);
|
||||||
|
println!("🌐 Interactive timeline saved to {}", file_path);
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,364 @@
|
|||||||
|
use crate::{
|
||||||
|
cpu::CPU,
|
||||||
|
instruction_set::instruction_data::InstructionData,
|
||||||
|
memory_manager::{
|
||||||
|
MemoryStorable,
|
||||||
|
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
||||||
|
},
|
||||||
|
tracing::Trace,
|
||||||
|
utility::{add_offset_r1, add_offset_rd},
|
||||||
|
};
|
||||||
|
use std::io::Write;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
#[cfg(feature = "profile_time")]
|
||||||
|
impl Trace {
|
||||||
|
///////////////////////////////////////////////////////////////
|
||||||
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
|
///////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||||
|
let core_indx = core_indx as usize;
|
||||||
|
if self.core_start_time.get(&core_indx).unwrap().is_none() {
|
||||||
|
self.core_start_time.insert(core_indx, Some(Instant::now()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
|
||||||
|
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||||
|
let core_indx = core_indx as usize;
|
||||||
|
let Self {
|
||||||
|
instruction_times,
|
||||||
|
core_start_time,
|
||||||
|
start_time,
|
||||||
|
} = self;
|
||||||
|
let now = Instant::now();
|
||||||
|
instruction_times
|
||||||
|
.get_mut(name)
|
||||||
|
.unwrap()
|
||||||
|
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
|
||||||
|
self.core_start_time.insert(core_indx, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "sldi");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "sld");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "sadd");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "ssub");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "smul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "saddi");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "smuli");
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
///////////////////Matrix/vector Instructions////////////////////
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "setbw");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||||
|
[M]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||||
|
[M]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "mvmul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvadd");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvsub");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvmul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvdmul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvmax");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vavg");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vrelu");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vtanh");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vsigm");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vsoftmax");
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
/////Communication/synchronization Instructions/////////////////
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "ld");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "st");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "lldi");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "lmv");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "send");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "recv");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
use std::{fs::File, path::PathBuf};
|
||||||
|
|
||||||
|
pub mod pretty_print;
|
||||||
|
pub mod tracing_isa;
|
||||||
|
|
||||||
|
pub struct Trace {
|
||||||
|
out_files: Vec<File>,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl Trace {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
out_files: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
|
||||||
|
path.pop();
|
||||||
|
for i in 0..num_core {
|
||||||
|
path.push(format!("TraceCore{}", i));
|
||||||
|
let file = File::create(&path).expect("Can not create file");
|
||||||
|
self.out_files.push(file);
|
||||||
|
path.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
+1
-10
@@ -1,4 +1,4 @@
|
|||||||
use crate::tracing::pretty_print;
|
use crate::{tracing::trace::pretty_print, utility::add_offset_r2};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -13,7 +13,6 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
#[cfg(feature = "tracing")]
|
|
||||||
impl Trace {
|
impl Trace {
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
/////////////////Scalar/register Instructions//////////////////
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
@@ -284,7 +283,6 @@ impl Trace {
|
|||||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||||
F: UpcastDestTraits<F> + MemoryStorable,
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
{
|
{
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -358,8 +356,6 @@ impl Trace {
|
|||||||
T: UpcastDestTraits<T> + MemoryStorable,
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
F: UpcastDestTraits<F> + MemoryStorable,
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
{
|
{
|
||||||
use crate::{tracing::pretty_print, utility::add_offset_r2};
|
|
||||||
|
|
||||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -990,8 +986,6 @@ impl Trace {
|
|||||||
/////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -1044,8 +1038,6 @@ impl Trace {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -1138,7 +1130,6 @@ impl Trace {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
add_pim_library(OMPimCommon
|
add_pim_library(OMPimCommon
|
||||||
PimCommon.cpp
|
IR/AddressAnalysis.cpp
|
||||||
|
IR/CoreBlockUtils.cpp
|
||||||
|
IR/EntryPointUtils.cpp
|
||||||
|
IR/ShapeUtils.cpp
|
||||||
|
IR/WeightUtils.cpp
|
||||||
|
Support/DebugDump.cpp
|
||||||
|
Support/Diagnostics.cpp
|
||||||
|
Support/FileSystemUtils.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,259 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (!moduleOp || !getGlobalOp)
|
||||||
|
return {};
|
||||||
|
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
if (!knowledge)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto iter = knowledge->aliases.find(value);
|
||||||
|
while (iter != knowledge->aliases.end()) {
|
||||||
|
value = iter->second;
|
||||||
|
iter = knowledge->aliases.find(value);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
if (mlir::isa<mlir::BlockArgument>(value))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
mlir::Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
|
||||||
|
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||||
|
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||||
|
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||||
|
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||||
|
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
if (knowledge) {
|
||||||
|
auto iter = knowledge->indexValues.find(value);
|
||||||
|
if (iter != knowledge->indexValues.end())
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
|
||||||
|
if (constantOp) {
|
||||||
|
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
|
||||||
|
return integerAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
||||||
|
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
||||||
|
|
||||||
|
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return *lhs + *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return *lhs - *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return *lhs * *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||||
|
return mlir::failure();
|
||||||
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||||
|
return mlir::failure();
|
||||||
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
||||||
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
|
||||||
|
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
||||||
|
if (!integerAttr)
|
||||||
|
return mlir::failure();
|
||||||
|
return integerAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
|
||||||
|
const StaticValueKnowledge* knowledge) {
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (mlir::isa<mlir::BlockArgument>(value))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
mlir::Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
||||||
|
if (!tiedOperand)
|
||||||
|
return mlir::failure();
|
||||||
|
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
|
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||||
|
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||||
|
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||||
|
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value = yieldedValue;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||||
|
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||||
|
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> offsets;
|
||||||
|
llvm::SmallVector<int64_t> sizes;
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||||
|
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||||
|
strides.reserve(subviewOp.getMixedStrides().size());
|
||||||
|
|
||||||
|
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||||
|
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||||
|
if (failed(resolvedOffset))
|
||||||
|
return mlir::failure();
|
||||||
|
offsets.push_back(*resolvedOffset);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
|
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||||
|
if (failed(resolvedSize))
|
||||||
|
return mlir::failure();
|
||||||
|
sizes.push_back(*resolvedSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||||
|
if (failed(resolvedStride))
|
||||||
|
return mlir::failure();
|
||||||
|
strides.push_back(*resolvedStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||||
|
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
||||||
|
value = resolveAlias(castOp.getSource(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveIndexValueImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
|
||||||
|
return resolveContiguousAddressImpl(value, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||||
|
const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveContiguousAddressImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Describes a value as a base addressable object plus a statically known
|
||||||
|
/// byte offset after peeling aliases, casts, and contiguous subviews.
|
||||||
|
struct ResolvedContiguousAddress {
|
||||||
|
mlir::Value base;
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Records compile-time facts used when interpreting address arithmetic and
|
||||||
|
/// loop-carried aliases inside PIM regions.
|
||||||
|
struct StaticValueKnowledge {
|
||||||
|
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||||
|
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||||
|
|
||||||
|
StaticValueKnowledge() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|
||||||
|
/// Resolves a value to contiguous backing storage when that storage can be
|
||||||
|
/// proven statically from aliases, DPS ties, casts, and subviews.
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||||
|
const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
/// Statically evaluates index-like SSA values, including simple integer
|
||||||
|
/// arithmetic and loop facts recorded in `knowledge`.
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||||
|
/// loop-carried memref/result.
|
||||||
|
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,745 @@
|
|||||||
|
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
||||||
|
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
||||||
|
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace compact_asm {
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
enum class ListDelimiter {
|
||||||
|
Square,
|
||||||
|
Paren
|
||||||
|
};
|
||||||
|
|
||||||
|
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||||
|
if (delimiter == ListDelimiter::Square)
|
||||||
|
return parser.parseLSquare();
|
||||||
|
return parser.parseLParen();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||||
|
if (delimiter == ListDelimiter::Square)
|
||||||
|
return parser.parseOptionalRSquare();
|
||||||
|
return parser.parseOptionalRParen();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT>
|
||||||
|
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
||||||
|
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT>
|
||||||
|
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
||||||
|
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename EntryT, typename ParseEntryFn>
|
||||||
|
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<EntryT>& entries,
|
||||||
|
ParseEntryFn parseEntry) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
EntryT entry;
|
||||||
|
if (parseEntry(entry))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
entries.push_back(entry);
|
||||||
|
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline ParseResult
|
||||||
|
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<IntT> subgroup;
|
||||||
|
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(values, subgroup);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
int64_t first = 0;
|
||||||
|
if (parser.parseInteger(first))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
int64_t last = 0;
|
||||||
|
if (parser.parseInteger(last) || last < first)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||||
|
|
||||||
|
int64_t step = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||||
|
if (parser.parseInteger(step) || step <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||||
|
}
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
if ((last - first) % step != 0) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"range end must be reachable from start using the given step");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t value = first; value <= last; value += step)
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
values.push_back(static_cast<IntT>(value));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
values.push_back(static_cast<IntT>(first));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline ParseResult
|
||||||
|
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RangeT, typename PrintEntryFn>
|
||||||
|
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||||
|
for (size_t index = 0; index < entries.size();) {
|
||||||
|
size_t runEnd = index + 1;
|
||||||
|
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
||||||
|
++runEnd;
|
||||||
|
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printEntry(entries[index]);
|
||||||
|
size_t runLength = runEnd - index;
|
||||||
|
if (runLength > 1)
|
||||||
|
printer << " x" << runLength;
|
||||||
|
index = runEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT, typename IntT>
|
||||||
|
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
|
||||||
|
struct FlatCompression {
|
||||||
|
enum class Kind {
|
||||||
|
Single,
|
||||||
|
EqualRun,
|
||||||
|
Progression
|
||||||
|
};
|
||||||
|
|
||||||
|
Kind kind = Kind::Single;
|
||||||
|
size_t covered = 1;
|
||||||
|
size_t repeatCount = 1;
|
||||||
|
size_t progressionValueCount = 1;
|
||||||
|
int64_t step = 1;
|
||||||
|
IntT firstValue {};
|
||||||
|
IntT lastValue {};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto computeFlatCompression = [&](size_t start) {
|
||||||
|
FlatCompression compression;
|
||||||
|
compression.firstValue = values[start];
|
||||||
|
compression.lastValue = values[start];
|
||||||
|
|
||||||
|
auto findEqualRunEnd = [&](size_t runStart) {
|
||||||
|
size_t runEnd = runStart + 1;
|
||||||
|
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||||
|
++runEnd;
|
||||||
|
return runEnd;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t firstRunEnd = findEqualRunEnd(start);
|
||||||
|
compression.repeatCount = firstRunEnd - start;
|
||||||
|
size_t progressionEnd = firstRunEnd;
|
||||||
|
int64_t step = 0;
|
||||||
|
IntT lastValue = values[start];
|
||||||
|
|
||||||
|
if (firstRunEnd < values.size()) {
|
||||||
|
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||||
|
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||||
|
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||||
|
progressionEnd = secondRunEnd;
|
||||||
|
lastValue = values[firstRunEnd];
|
||||||
|
size_t currentRunStart = secondRunEnd;
|
||||||
|
while (currentRunStart < values.size()) {
|
||||||
|
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||||
|
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||||
|
break;
|
||||||
|
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||||
|
break;
|
||||||
|
lastValue = values[currentRunStart];
|
||||||
|
progressionEnd = currentRunEnd;
|
||||||
|
currentRunStart = currentRunEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
step = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compression.covered = 1;
|
||||||
|
if (progressionEnd > firstRunEnd) {
|
||||||
|
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||||
|
if (progressionValueCount >= 3) {
|
||||||
|
compression.kind = FlatCompression::Kind::Progression;
|
||||||
|
compression.covered = progressionEnd - start;
|
||||||
|
compression.progressionValueCount = progressionValueCount;
|
||||||
|
compression.step = step;
|
||||||
|
compression.lastValue = lastValue;
|
||||||
|
return compression;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compression.repeatCount > 1) {
|
||||||
|
compression.kind = FlatCompression::Kind::EqualRun;
|
||||||
|
compression.covered = compression.repeatCount;
|
||||||
|
return compression;
|
||||||
|
}
|
||||||
|
|
||||||
|
return compression;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto findRepeatedSublist = [&](size_t start) {
|
||||||
|
size_t bestLength = 0;
|
||||||
|
size_t bestRepeatCount = 1;
|
||||||
|
size_t remaining = values.size() - start;
|
||||||
|
|
||||||
|
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||||
|
size_t repeatCount = 1;
|
||||||
|
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||||
|
while (start + (repeatCount + 1) * length <= values.size()
|
||||||
|
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||||
|
++repeatCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (repeatCount <= 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t covered = length * repeatCount;
|
||||||
|
size_t bestCovered = bestLength * bestRepeatCount;
|
||||||
|
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||||
|
bestLength = length;
|
||||||
|
bestRepeatCount = repeatCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::pair(bestLength, bestRepeatCount);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t index = 0; index < values.size();) {
|
||||||
|
if (index != 0)
|
||||||
|
stream << ", ";
|
||||||
|
|
||||||
|
FlatCompression flat = computeFlatCompression(index);
|
||||||
|
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||||
|
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||||
|
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||||
|
printOpenDelimiter(stream, ListDelimiter::Paren);
|
||||||
|
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
|
||||||
|
printCloseDelimiter(stream, ListDelimiter::Paren);
|
||||||
|
stream << " x" << sublistRepeatCount;
|
||||||
|
index += repeatedSublistCoverage;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (flat.kind) {
|
||||||
|
case FlatCompression::Kind::Progression:
|
||||||
|
stream << flat.firstValue << " to " << flat.lastValue;
|
||||||
|
if (flat.step != 1)
|
||||||
|
stream << " by " << flat.step;
|
||||||
|
if (flat.repeatCount > 1)
|
||||||
|
stream << " x" << flat.repeatCount;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
case FlatCompression::Kind::EqualRun:
|
||||||
|
stream << flat.firstValue << " x" << flat.repeatCount;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
case FlatCompression::Kind::Single:
|
||||||
|
stream << flat.firstValue;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT, typename IntT>
|
||||||
|
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(stream, delimiter);
|
||||||
|
printCompressedIntegerEntries(stream, values);
|
||||||
|
printCloseDelimiter(stream, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||||
|
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||||
|
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
|
||||||
|
for (size_t index = 0; index < values.size();) {
|
||||||
|
size_t equalRunEnd = index + 1;
|
||||||
|
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||||
|
++equalRunEnd;
|
||||||
|
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
if (equalRunEnd - index > 1) {
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
printer << " x" << (equalRunEnd - index);
|
||||||
|
index = equalRunEnd;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t rangeEnd = index + 1;
|
||||||
|
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||||
|
while (rangeEnd < values.size()) {
|
||||||
|
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||||
|
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||||
|
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++rangeEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||||
|
while (rangeEnd < values.size()) {
|
||||||
|
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||||
|
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||||
|
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++rangeEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
if (rangeEnd - index >= 3) {
|
||||||
|
printer << " to ";
|
||||||
|
printer.printOperand(values[rangeEnd - 1]);
|
||||||
|
}
|
||||||
|
else if (rangeEnd - index == 2) {
|
||||||
|
printer << ", ";
|
||||||
|
printer.printOperand(values[index + 1]);
|
||||||
|
}
|
||||||
|
index = rangeEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
|
||||||
|
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printCompressedValueSequence(printer, values);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printCompressedTypeSequence(printer, types);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||||
|
Type firstType;
|
||||||
|
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||||
|
if (!firstTypeResult.has_value()) {
|
||||||
|
if (allowEmpty)
|
||||||
|
return success();
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "expected type");
|
||||||
|
}
|
||||||
|
if (failed(*firstTypeResult))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto appendType = [&](Type type) -> ParseResult {
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
types.push_back(type);
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (appendType(firstType))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
Type nextType;
|
||||||
|
if (parser.parseType(nextType) || appendType(nextType))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||||
|
OpAsmParser::UnresolvedOperand firstOperand,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
OpAsmParser::UnresolvedOperand lastOperand;
|
||||||
|
if (parser.parseOperand(lastOperand))
|
||||||
|
return failure();
|
||||||
|
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
|
||||||
|
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
|
||||||
|
operands.push_back({firstOperand.location, firstOperand.name, number});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
operands.push_back(firstOperand);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
OpAsmParser::UnresolvedOperand firstOperand;
|
||||||
|
if (parser.parseOperand(firstOperand))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
|
||||||
|
return failure();
|
||||||
|
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||||
|
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SmallVector<Value> valueVec(values.begin(), values.end());
|
||||||
|
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
|
||||||
|
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
|
||||||
|
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||||
|
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SmallVector<Type> typeVec(types.begin(), types.end());
|
||||||
|
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
|
||||||
|
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
|
||||||
|
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
for (size_t index = 0; index < tupleSize; ++index) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
printer << " x" << (values.size() / tupleSize);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
for (size_t index = 0; index < tupleSize; ++index) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printType(types[index]);
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
printer << " x" << (types.size() / tupleSize);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
|
||||||
|
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(operands, tupleOperands);
|
||||||
|
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
tupleOperands.clear();
|
||||||
|
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(operands, tupleOperands);
|
||||||
|
}
|
||||||
|
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult
|
||||||
|
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<Type> tupleTypes;
|
||||||
|
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(types, tupleTypes);
|
||||||
|
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
tupleTypes.clear();
|
||||||
|
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(types, tupleTypes);
|
||||||
|
}
|
||||||
|
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
Type type;
|
||||||
|
if (parser.parseType(type))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
types.push_back(type);
|
||||||
|
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||||
|
if (block.getNumArguments() == 0) {
|
||||||
|
printer << "() = ()";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (block.getNumArguments() == 1) {
|
||||||
|
printer.printOperand(block.getArgument(0));
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||||
|
OpAsmParser::Argument firstArgument,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
OpAsmParser::Argument lastArgument;
|
||||||
|
if (parser.parseArgument(lastArgument))
|
||||||
|
return failure();
|
||||||
|
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||||
|
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||||
|
}
|
||||||
|
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||||
|
arguments.push_back(argument);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
arguments.push_back(firstArgument);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
OpAsmParser::Argument firstArgument;
|
||||||
|
if (parser.parseArgument(firstArgument))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||||
|
argument.type = inputType;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
if (succeeded(parser.parseOptionalRParen())) {
|
||||||
|
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAsmParser::Argument firstArgument;
|
||||||
|
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseRParen() || parser.parseEqual()
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
arguments.push_back(argument);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace compact_asm
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||||
|
return mlir::isa<mlir::arith::ConstantOp,
|
||||||
|
mlir::arith::AddIOp,
|
||||||
|
mlir::arith::SubIOp,
|
||||||
|
mlir::arith::MulIOp,
|
||||||
|
mlir::arith::DivUIOp,
|
||||||
|
mlir::arith::RemUIOp,
|
||||||
|
mlir::arith::IndexCastOp,
|
||||||
|
mlir::memref::AllocOp,
|
||||||
|
mlir::memref::SubViewOp,
|
||||||
|
mlir::memref::CastOp,
|
||||||
|
mlir::memref::CollapseShapeOp,
|
||||||
|
mlir::memref::ExpandShapeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
walkPimCoreBlock(mlir::Block& block,
|
||||||
|
const StaticValueKnowledge& knowledge,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
for (mlir::Operation& op : block) {
|
||||||
|
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||||
|
mlir::Block& loopBody = forOp.getRegion().front();
|
||||||
|
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||||
|
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||||
|
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||||
|
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
||||||
|
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
||||||
|
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
||||||
|
StaticValueKnowledge loopKnowledge = knowledge;
|
||||||
|
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||||
|
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
||||||
|
loopKnowledge.aliases[iterArg] = iterValue;
|
||||||
|
|
||||||
|
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
||||||
|
hasFailure = true;
|
||||||
|
|
||||||
|
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
|
||||||
|
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
||||||
|
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (failed(callback(op, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
return mlir::success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns true for ops in a `pim.core` body that only participate in static
|
||||||
|
/// address or index computation and therefore do not emit PIM instructions.
|
||||||
|
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||||
|
|
||||||
|
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
|
||||||
|
/// their bounds are known and invoking `callback` only on instruction-emitting
|
||||||
|
/// operations.
|
||||||
|
mlir::LogicalResult
|
||||||
|
walkPimCoreBlock(mlir::Block& block,
|
||||||
|
const StaticValueKnowledge& knowledge,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
|
||||||
|
if (!moduleOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
|
||||||
|
if (entryPoints.size() > 1) {
|
||||||
|
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
if (!entryPoints.empty()) {
|
||||||
|
auto entryPointAttr =
|
||||||
|
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||||
|
if (!entryPointAttr) {
|
||||||
|
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||||
|
if (!entryFunc) {
|
||||||
|
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||||
|
<< entryPointAttr.getLeafReference().getValue();
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
return entryFunc;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
|
||||||
|
return mainGraphFunc;
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
|
||||||
|
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
|
||||||
|
if (!funcOp.isExternal())
|
||||||
|
nonExternalFuncs.push_back(funcOp);
|
||||||
|
if (nonExternalFuncs.size() == 1)
|
||||||
|
return nonExternalFuncs.front();
|
||||||
|
|
||||||
|
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Resolves the function the PIM pipeline should treat as its entry point.
|
||||||
|
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
|
||||||
|
/// non-external function if the module is otherwise unambiguous.
|
||||||
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
|
||||||
|
llvm::SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
|
||||||
|
llvm::SmallVector<int64_t> indices(shape.size(), 0);
|
||||||
|
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||||
|
indices[dim] = linearIndex / stride;
|
||||||
|
linearIndex %= stride;
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
|
||||||
|
int64_t linearIndex = 0;
|
||||||
|
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||||
|
linearIndex += index * stride;
|
||||||
|
return linearIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t dim : shape)
|
||||||
|
numElements *= dim;
|
||||||
|
return numElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides) {
|
||||||
|
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||||
|
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstNonZeroOffset = std::find_if(
|
||||||
|
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return offset != 0;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||||
|
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||||
|
if (size > dimension - offset)
|
||||||
|
return false;
|
||||||
|
++firstNonZeroOffset;
|
||||||
|
|
||||||
|
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, dimension] = sizeAndShape;
|
||||||
|
return size != dimension;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstDifferentSize != sizesAndShape.end()) {
|
||||||
|
++firstDifferentSize;
|
||||||
|
|
||||||
|
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, _dimension] = sizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||||
|
|
||||||
|
void markWeightAlways(mlir::Operation* op) {
|
||||||
|
assert(op && "expected valid op");
|
||||||
|
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
|
bool found = false;
|
||||||
|
parentOp.walk([&](mlir::Operation* op) {
|
||||||
|
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||||
|
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||||
|
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||||
|
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||||
|
});
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||||
|
auto weights = parentOp.getWeights();
|
||||||
|
llvm::SmallSet<unsigned, 8> visited;
|
||||||
|
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||||
|
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||||
|
callback(parentOp->getOpOperand(weightIndex));
|
||||||
|
};
|
||||||
|
|
||||||
|
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||||
|
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
||||||
|
mlir::Operation* user = use.getOwner();
|
||||||
|
unsigned operandIndex = use.getOperandNumber();
|
||||||
|
|
||||||
|
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
|
||||||
|
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||||
|
llvm::SmallPtrSet<mlir::Value, 8> visited;
|
||||||
|
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
|
||||||
|
if (!visited.insert(currentValue).second)
|
||||||
|
return true;
|
||||||
|
if (currentValue.use_empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
|
||||||
|
if (isSpatialMvmVmmWeightUse(use))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
mlir::Operation* user = use.getOwner();
|
||||||
|
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
|
||||||
|
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||||
|
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
|
||||||
|
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||||
|
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
||||||
|
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||||
|
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
|
||||||
|
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return walkUses(value, walkUses);
|
||||||
|
}
|
||||||
|
|
||||||
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||||
|
assert(root && "expected valid root op");
|
||||||
|
root->walk([&](pim::PimCoreOp coreOp) {
|
||||||
|
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||||
|
auto weights = coreOp.getWeights();
|
||||||
|
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||||
|
if (weightIndex < weights.size())
|
||||||
|
callback(coreOp->getOpOperand(weightIndex));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
auto weights = coreBatchOp.getWeights();
|
||||||
|
for (auto weight : weights)
|
||||||
|
for (mlir::OpOperand& use : weight.getUses())
|
||||||
|
if (use.getOwner() == coreBatchOp.getOperation())
|
||||||
|
callback(use);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool hasWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||||
|
/// weight across later PIM lowering/codegen stages.
|
||||||
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||||
|
|
||||||
|
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
|
||||||
|
/// allowing later passes to preserve it as a dedicated weight-like object.
|
||||||
|
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||||
|
|
||||||
|
/// Visits weight operands consumed by Pim core ops/core batches so downstream
|
||||||
|
/// passes can identify globals that must remain weight-backed.
|
||||||
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,546 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
||||||
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <filesystem>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
std::string getOutputDir() {
|
|
||||||
if (outputBaseName.empty() || outputBaseName == "-")
|
|
||||||
return {};
|
|
||||||
|
|
||||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
|
||||||
if (lastSlash == std::string::npos)
|
|
||||||
return ".";
|
|
||||||
return outputBaseName.substr(0, lastSlash);
|
|
||||||
}
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory) {
|
|
||||||
std::error_code errorCode;
|
|
||||||
std::filesystem::create_directories(directory, errorCode);
|
|
||||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
|
||||||
}
|
|
||||||
|
|
||||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
|
||||||
std::string outputDir = getOutputDir();
|
|
||||||
if (outputDir.empty())
|
|
||||||
return;
|
|
||||||
|
|
||||||
std::string dialectsDir = outputDir + "/dialects";
|
|
||||||
createDirectory(dialectsDir);
|
|
||||||
|
|
||||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
os << *moduleOp;
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
|
||||||
if (!moduleOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
|
||||||
if (entryPoints.size() > 1) {
|
|
||||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (!entryPoints.empty()) {
|
|
||||||
auto entryPointAttr =
|
|
||||||
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
|
||||||
if (!entryPointAttr) {
|
|
||||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
|
||||||
if (!entryFunc) {
|
|
||||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
|
||||||
<< entryPointAttr.getLeafReference().getValue();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return entryFunc;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
|
||||||
return mainGraphFunc;
|
|
||||||
|
|
||||||
SmallVector<func::FuncOp> nonExternalFuncs;
|
|
||||||
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
|
||||||
if (!funcOp.isExternal())
|
|
||||||
nonExternalFuncs.push_back(funcOp);
|
|
||||||
if (nonExternalFuncs.size() == 1)
|
|
||||||
return nonExternalFuncs.front();
|
|
||||||
|
|
||||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
|
||||||
|
|
||||||
void markWeightAlways(Operation* op) {
|
|
||||||
assert(op && "expected valid op");
|
|
||||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
|
||||||
}
|
|
||||||
|
|
||||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
|
||||||
if (!moduleOp || !getGlobalOp)
|
|
||||||
return {};
|
|
||||||
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
if (!channelNewOp) {
|
|
||||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
// channelNewOp should have two users: `op` and a
|
|
||||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
|
||||||
auto channelUsers = channelNewOp->getUsers();
|
|
||||||
auto usersIterator = channelUsers.begin();
|
|
||||||
auto firstUser = *usersIterator;
|
|
||||||
usersIterator++;
|
|
||||||
if (usersIterator == channelUsers.end()) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"only one found.");
|
|
||||||
channelNewOp->dump();
|
|
||||||
op->dump();
|
|
||||||
channelNewOp->getParentOp()->dump();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
auto secondUser = *usersIterator;
|
|
||||||
usersIterator++;
|
|
||||||
if (usersIterator != channelUsers.end()) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"more than two found.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
Operation* notOpUser;
|
|
||||||
if (firstUser == op) {
|
|
||||||
notOpUser = secondUser;
|
|
||||||
}
|
|
||||||
else if (secondUser == op) {
|
|
||||||
notOpUser = firstUser;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"and one of them must be me, but"
|
|
||||||
"none of them is actually me.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (opIsReceive) {
|
|
||||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
|
||||||
"me, the other is not a ChannelSendOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return notOpUser;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
|
||||||
"me, the other is not a ChannelReceiveOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return notOpUser;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
|
||||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
|
||||||
SmallVector<int64_t> indices(shape.size(), 0);
|
|
||||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
|
||||||
indices[dim] = linearIndex / stride;
|
|
||||||
linearIndex %= stride;
|
|
||||||
}
|
|
||||||
return indices;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
|
||||||
int64_t linearIndex = 0;
|
|
||||||
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
|
||||||
linearIndex += index * stride;
|
|
||||||
return linearIndex;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
|
||||||
int64_t numElements = 1;
|
|
||||||
for (int64_t dim : shape)
|
|
||||||
numElements *= dim;
|
|
||||||
return numElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
|
||||||
ArrayRef<int64_t> offsets,
|
|
||||||
ArrayRef<int64_t> sizes,
|
|
||||||
ArrayRef<int64_t> strides) {
|
|
||||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
|
||||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstNonZeroOffset = std::find_if(
|
|
||||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return offset != 0;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
|
||||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
|
||||||
if (size > dimension - offset)
|
|
||||||
return false;
|
|
||||||
++firstNonZeroOffset;
|
|
||||||
|
|
||||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstDifferentSize != sizesAndShape.end()) {
|
|
||||||
++firstDifferentSize;
|
|
||||||
|
|
||||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, _dimension] = sizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
if (!knowledge)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
auto iter = knowledge->aliases.find(value);
|
|
||||||
while (iter != knowledge->aliases.end()) {
|
|
||||||
value = iter->second;
|
|
||||||
iter = knowledge->aliases.find(value);
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
|
|
||||||
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
|
|
||||||
// and when propagating yielded values across iterations during static unrolling.
|
|
||||||
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
|
||||||
return value;
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
if (auto result = dyn_cast<OpResult>(value))
|
|
||||||
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
|
||||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
|
||||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
|
||||||
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
|
||||||
|
|
||||||
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
if (knowledge) {
|
|
||||||
auto iter = knowledge->indexValues.find(value);
|
|
||||||
if (iter != knowledge->indexValues.end())
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
|
||||||
if (constantOp) {
|
|
||||||
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
|
|
||||||
return integerAttr.getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
|
|
||||||
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
|
||||||
|
|
||||||
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return failure();
|
|
||||||
return *lhs + *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return failure();
|
|
||||||
return *lhs - *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return failure();
|
|
||||||
return *lhs * *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
|
||||||
return failure();
|
|
||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
|
||||||
return failure();
|
|
||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
|
||||||
if (auto attr = dyn_cast<Attribute>(ofr)) {
|
|
||||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
|
||||||
if (!integerAttr)
|
|
||||||
return failure();
|
|
||||||
return integerAttr.getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
|
|
||||||
const StaticValueKnowledge* knowledge) {
|
|
||||||
int64_t byteOffset = 0;
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (isa<BlockArgument>(value))
|
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
|
||||||
if (!tiedOperand)
|
|
||||||
return failure();
|
|
||||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
|
|
||||||
auto result = dyn_cast<OpResult>(value);
|
|
||||||
if (!result)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Trace the loop carry back to its underlying memref, then if that memref is the
|
|
||||||
// loop's own iter-arg we know the base comes from the corresponding init arg
|
|
||||||
// (every iteration yields the same backing memory in the DPS sense).
|
|
||||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
||||||
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
|
||||||
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
|
|
||||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
|
||||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
|
||||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
value = yieldedValue;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
|
||||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
|
||||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> offsets;
|
|
||||||
SmallVector<int64_t> sizes;
|
|
||||||
SmallVector<int64_t> strides;
|
|
||||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
|
||||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
|
||||||
strides.reserve(subviewOp.getMixedStrides().size());
|
|
||||||
|
|
||||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
|
||||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
|
||||||
if (failed(resolvedOffset))
|
|
||||||
return failure();
|
|
||||||
offsets.push_back(*resolvedOffset);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
|
||||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
|
||||||
if (failed(resolvedSize))
|
|
||||||
return failure();
|
|
||||||
sizes.push_back(*resolvedSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
|
||||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
|
||||||
if (failed(resolvedStride))
|
|
||||||
return failure();
|
|
||||||
strides.push_back(*resolvedStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
|
||||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
|
||||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
|
||||||
value = resolveAlias(castOp.getSource(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
|
||||||
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
|
|
||||||
|
|
||||||
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveIndexValueImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
|
||||||
return resolveContiguousAddressImpl(value, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveContiguousAddressImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isCoreStaticAddressOp(Operation* op) {
|
|
||||||
return isa<arith::ConstantOp,
|
|
||||||
arith::AddIOp,
|
|
||||||
arith::SubIOp,
|
|
||||||
arith::MulIOp,
|
|
||||||
arith::DivUIOp,
|
|
||||||
arith::RemUIOp,
|
|
||||||
arith::IndexCastOp,
|
|
||||||
memref::AllocOp,
|
|
||||||
memref::SubViewOp,
|
|
||||||
memref::CastOp,
|
|
||||||
memref::CollapseShapeOp,
|
|
||||||
memref::ExpandShapeOp>(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult walkPimCoreBlock(Block& block,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
|
|
||||||
bool hasFailure = false;
|
|
||||||
for (Operation& op : block) {
|
|
||||||
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
|
||||||
Block& loopBody = forOp.getRegion().front();
|
|
||||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
|
||||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
|
||||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
|
||||||
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
|
||||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
|
||||||
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
|
||||||
StaticValueKnowledge loopKnowledge = knowledge;
|
|
||||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
|
||||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
|
||||||
loopKnowledge.aliases[iterArg] = iterValue;
|
|
||||||
|
|
||||||
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
|
||||||
hasFailure = true;
|
|
||||||
|
|
||||||
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
|
|
||||||
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
|
||||||
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (failed(callback(op, knowledge)))
|
|
||||||
hasFailure = true;
|
|
||||||
}
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -7,82 +7,22 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
struct ResolvedContiguousAddress {
|
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
|
||||||
mlir::Value base;
|
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
|
||||||
int64_t byteOffset = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct StaticValueKnowledge {
|
|
||||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
|
||||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
|
||||||
|
|
||||||
StaticValueKnowledge() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string getOutputDir();
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory);
|
|
||||||
|
|
||||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
|
||||||
|
|
||||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
|
||||||
|
|
||||||
bool hasWeightAlways(mlir::Operation* op);
|
|
||||||
|
|
||||||
void markWeightAlways(mlir::Operation* op);
|
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Operation*>
|
|
||||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t>
|
|
||||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
|
||||||
llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
|
||||||
const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
|
|
||||||
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
|
|
||||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
|
|
||||||
/// only contribute to static addressing or index computations (arith integer math,
|
|
||||||
/// memref view ops, memref.alloc, arith.constant).
|
|
||||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
|
||||||
|
|
||||||
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
|
|
||||||
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
|
|
||||||
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
|
|
||||||
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
|
|
||||||
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
|
|
||||||
mlir::LogicalResult
|
|
||||||
walkPimCoreBlock(mlir::Block& block,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string dialectsDir = outputDir + "/dialects";
|
||||||
|
createDirectory(dialectsDir);
|
||||||
|
|
||||||
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
|
llvm::raw_os_ostream os(file);
|
||||||
|
mlir::OpPrintingFlags flags;
|
||||||
|
flags.elideLargeElementsAttrs();
|
||||||
|
moduleOp.print(os, flags);
|
||||||
|
os.flush();
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Emits a MLIR snapshot under the current compiler output
|
||||||
|
/// directory for pass-level debugging.
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
|
||||||
|
return op->emitOpError() << "requires statically shaped " << valueDescription;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||||
|
llvm::StringRef valueDescription,
|
||||||
|
int64_t actualRank,
|
||||||
|
llvm::ArrayRef<int64_t> supportedRanks) {
|
||||||
|
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
|
||||||
|
if (supportedRanks.empty())
|
||||||
|
return diag;
|
||||||
|
|
||||||
|
diag << "; supported rank";
|
||||||
|
if (supportedRanks.size() != 1)
|
||||||
|
diag << 's';
|
||||||
|
diag << ' ';
|
||||||
|
|
||||||
|
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
|
||||||
|
return diag;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::InFlightDiagnostic
|
||||||
|
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
|
||||||
|
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
|
||||||
|
llvm::StringRef action,
|
||||||
|
llvm::StringRef path,
|
||||||
|
const std::error_code& errorCode) {
|
||||||
|
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <system_error>
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||||
|
|
||||||
|
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
|
||||||
|
/// accepted by the current lowering/codegen path.
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||||
|
llvm::StringRef valueDescription,
|
||||||
|
int64_t actualRank,
|
||||||
|
llvm::ArrayRef<int64_t> supportedRanks);
|
||||||
|
|
||||||
|
/// Emits a consistent diagnostic for missing symbol/global references.
|
||||||
|
mlir::InFlightDiagnostic
|
||||||
|
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
|
||||||
|
|
||||||
|
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
|
||||||
|
/// the relevant IR location.
|
||||||
|
mlir::LogicalResult
|
||||||
|
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
|
||||||
|
return mlir::success(succeeded(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#include <filesystem>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::string getOutputDir() {
|
||||||
|
if (outputBaseName.empty() || outputBaseName == "-")
|
||||||
|
return {};
|
||||||
|
|
||||||
|
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||||
|
if (lastSlash == std::string::npos)
|
||||||
|
return ".";
|
||||||
|
return outputBaseName.substr(0, lastSlash);
|
||||||
|
}
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory) {
|
||||||
|
std::error_code errorCode;
|
||||||
|
std::filesystem::create_directories(directory, errorCode);
|
||||||
|
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns the directory that should hold PIM artifacts/debug dumps for the
|
||||||
|
/// current compiler invocation.
|
||||||
|
std::string getOutputDir();
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
|
|||||||
|
|
||||||
add_pim_library(OMPimCompilerUtils
|
add_pim_library(OMPimCompilerUtils
|
||||||
PimCompilerUtils.cpp
|
PimCompilerUtils.cpp
|
||||||
|
PimArtifactWriter.cpp
|
||||||
|
PimBatchEmission.cpp
|
||||||
PimCodeGen.cpp
|
PimCodeGen.cpp
|
||||||
|
PimWeightEmitter.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,123 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/Support/FileSystem.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
|
||||||
|
std::error_code errorCode;
|
||||||
|
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
|
||||||
|
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
|
||||||
|
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
||||||
|
hostFileStream.close();
|
||||||
|
return CompilerSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes
|
||||||
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
|
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||||
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (hasWeightAlways(getGlobalOp))
|
||||||
|
return;
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp)
|
||||||
|
return;
|
||||||
|
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||||
|
return;
|
||||||
|
auto initialValue = globalOp.getInitialValue();
|
||||||
|
if (!initialValue)
|
||||||
|
return;
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||||
|
if (!denseAttr)
|
||||||
|
return;
|
||||||
|
|
||||||
|
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
|
||||||
|
ArrayRef<char> rawData = denseAttr.getRawData();
|
||||||
|
char* dst = memoryBuffer.data() + memEntry.address;
|
||||||
|
|
||||||
|
if (denseAttr.isSplat()) {
|
||||||
|
size_t elementSize = rawData.size();
|
||||||
|
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
|
||||||
|
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
|
||||||
|
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
assert(rawData.size() == memEntry.size && "Data size mismatch");
|
||||||
|
std::memcpy(dst, rawData.data(), rawData.size());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
|
||||||
|
memoryFileStream.close();
|
||||||
|
return CompilerSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||||
|
PimAcceleratorMemory& memory,
|
||||||
|
size_t maxCoreId,
|
||||||
|
json::Object xbarsPerArrayGroup,
|
||||||
|
StringRef outputDirPath) {
|
||||||
|
json::Object configJson;
|
||||||
|
|
||||||
|
configJson["core_cnt"] = maxCoreId + 1;
|
||||||
|
configJson["adc_count"] = 16;
|
||||||
|
configJson["cell_precision"] = 2;
|
||||||
|
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
||||||
|
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||||
|
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||||
|
|
||||||
|
json::Array inputsAddresses;
|
||||||
|
for (BlockArgument input : funcOp.getArguments())
|
||||||
|
inputsAddresses.push_back(memory.getValueAddress(input));
|
||||||
|
configJson["inputs_addresses"] = std::move(inputsAddresses);
|
||||||
|
|
||||||
|
json::Array outputsAddresses;
|
||||||
|
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||||
|
for (mlir::Value output : returnOp.getOperands())
|
||||||
|
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||||
|
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||||
|
|
||||||
|
auto configPath = (outputDirPath + "/config.json").str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream jsonOS(configPath, errorCode);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening config file: " << errorCode.message() << '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
jsonOS << json::Value(std::move(configJson)) << '\n';
|
||||||
|
jsonOS.close();
|
||||||
|
return CompilerSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
|
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
class PimAcceleratorMemory;
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
|
||||||
|
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||||
|
mlir::func::FuncOp funcOp,
|
||||||
|
PimAcceleratorMemory& memory,
|
||||||
|
llvm::StringRef outputDirPath);
|
||||||
|
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
|
||||||
|
PimAcceleratorMemory& memory,
|
||||||
|
size_t maxCoreId,
|
||||||
|
llvm::json::Object xbarsPerArrayGroup,
|
||||||
|
llvm::StringRef outputDirPath);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||||
|
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||||
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||||
|
SmallVector<int32_t> laneCoreIds;
|
||||||
|
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||||
|
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||||
|
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||||
|
return laneCoreIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||||
|
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||||
|
OpBuilder builder(scratchModule->getContext());
|
||||||
|
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||||
|
|
||||||
|
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||||
|
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||||
|
SmallVector<Value> laneWeights;
|
||||||
|
laneWeights.reserve(weightsPerLane);
|
||||||
|
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||||
|
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||||
|
|
||||||
|
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
auto scalarCore = pim::PimCoreOp::create(
|
||||||
|
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||||
|
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||||
|
IRMapping mapper;
|
||||||
|
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||||
|
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||||
|
|
||||||
|
builder.setInsertionPointToEnd(block);
|
||||||
|
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||||
|
if (isa<pim::PimHaltOp>(op)) {
|
||||||
|
pim::PimHaltOp::create(builder, op.getLoc());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||||
|
pim::PimSendOp::create(builder,
|
||||||
|
sendBatchOp.getLoc(),
|
||||||
|
mapper.lookup(sendBatchOp.getInput()),
|
||||||
|
sendBatchOp.getSizeAttr(),
|
||||||
|
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||||
|
pim::PimSendTensorOp::create(
|
||||||
|
builder,
|
||||||
|
sendTensorBatchOp.getLoc(),
|
||||||
|
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||||
|
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||||
|
auto scalarReceive =
|
||||||
|
pim::PimReceiveOp::create(builder,
|
||||||
|
receiveBatchOp.getLoc(),
|
||||||
|
receiveBatchOp.getOutput().getType(),
|
||||||
|
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||||
|
receiveBatchOp.getSizeAttr(),
|
||||||
|
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||||
|
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||||
|
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||||
|
builder,
|
||||||
|
receiveTensorBatchOp.getLoc(),
|
||||||
|
receiveTensorBatchOp.getOutput().getType(),
|
||||||
|
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||||
|
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||||
|
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||||
|
Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
|
||||||
|
if (!hostSource)
|
||||||
|
hostSource = memcpBatchOp.getHostSource();
|
||||||
|
|
||||||
|
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
|
||||||
|
memcpBatchOp.getLoc(),
|
||||||
|
memcpBatchOp.getOutput().getType(),
|
||||||
|
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||||
|
hostSource,
|
||||||
|
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||||
|
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||||
|
memcpBatchOp.getSizeAttr());
|
||||||
|
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = builder.clone(op, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||||
|
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||||
|
return callback(scalarCore);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+358
-301
@@ -1,30 +1,48 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/AsmState.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/IR/Verifier.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
|
#include "llvm/Support/Format.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <absl/types/compare.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cstdint>
|
||||||
|
#include <fstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
#include "Common/IR/CompactAsmUtils.hpp"
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
using namespace onnx_mlir::compact_asm;
|
||||||
|
|
||||||
|
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||||
|
auto type = cast<ShapedType>(value.getType());
|
||||||
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
|
}
|
||||||
|
|
||||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||||
auto type = cast<ShapedType>(value.getType());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
@@ -53,9 +71,22 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
|||||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||||
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||||
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||||
|
SmallVector<mlir::Value> args;
|
||||||
|
|
||||||
|
for (mlir::Value arg : funcOp.getArguments()) {
|
||||||
|
gatherMemEntry(arg);
|
||||||
|
args.push_back(arg);
|
||||||
|
}
|
||||||
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!hasWeightAlways(getGlobalOp)) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (globalMemrefOp.getName().starts_with("arg")) {
|
||||||
|
StringRef indexStr = globalMemrefOp.getName().substr(4);
|
||||||
|
int index = 0;
|
||||||
|
llvm::to_integer(indexStr, index, 10);
|
||||||
|
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
|
||||||
|
}
|
||||||
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||||
if (inserted)
|
if (inserted)
|
||||||
gatherMemEntry(getGlobalOp.getResult());
|
gatherMemEntry(getGlobalOp.getResult());
|
||||||
@@ -64,9 +95,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (mlir::Value arg : funcOp.getArguments())
|
|
||||||
gatherMemEntry(arg);
|
|
||||||
|
|
||||||
funcOp.walk([&](memref::AllocOp allocOp) {
|
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||||
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
gatherMemEntry(allocOp.getResult());
|
gatherMemEntry(allocOp.getResult());
|
||||||
@@ -84,6 +112,60 @@ void PimMemory::allocateCore(Operation* op) {
|
|||||||
allocateGatheredMemory();
|
allocateGatheredMemory();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string formatMemory(uint64_t bytes) {
|
||||||
|
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||||
|
int i = 0;
|
||||||
|
double size = static_cast<double>(bytes);
|
||||||
|
while (size >= 1024 && i < 6) {
|
||||||
|
size /= 1024;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
// Formats to 2 decimal places
|
||||||
|
std::string out;
|
||||||
|
llvm::raw_string_ostream rss(out);
|
||||||
|
rss << llvm::format("%.2f ", size) << units[i];
|
||||||
|
return rss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
||||||
|
os << "\tNumber of allocas: " << row.numAlloca << "\n";
|
||||||
|
os << "\tAllocated memory: " << formatMemory(row.sizeAlloca) << "\n";
|
||||||
|
os << "\tNumber of globals: " << row.numGlobal << "\n";
|
||||||
|
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
|
||||||
|
MemoryReportRow result = lhs;
|
||||||
|
result.numAlloca += rhs.numAlloca;
|
||||||
|
result.sizeAlloca += rhs.sizeAlloca;
|
||||||
|
result.numGlobal += rhs.numGlobal;
|
||||||
|
result.sizeGlobal += rhs.sizeGlobal;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
MemoryReportRow PimMemory::getReportRow() const {
|
||||||
|
MemoryReportRow row;
|
||||||
|
for (auto& [val, memEntry] : globalMemEntriesMap) {
|
||||||
|
if (auto op = val.getDefiningOp()) {
|
||||||
|
if (isa<memref::AllocOp>(op)) {
|
||||||
|
row.numAlloca++;
|
||||||
|
row.sizeAlloca += memEntry.size;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<memref::GetGlobalOp>(op)) {
|
||||||
|
row.numGlobal++;
|
||||||
|
row.sizeGlobal += memEntry.size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return row;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimMemory::remove(mlir::Value val) {
|
||||||
|
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
||||||
|
globalMemEntriesMap.erase(removeIter);
|
||||||
|
}
|
||||||
|
|
||||||
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||||
auto iter = globalMemEntriesMap.find(value);
|
auto iter = globalMemEntriesMap.find(value);
|
||||||
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
||||||
@@ -124,6 +206,99 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
|
|||||||
return iter->second.address + resolvedAddress->byteOffset;
|
return iter->second.address + resolvedAddress->byteOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimAcceleratorMemory::reportHost() {
|
||||||
|
hostReportRow = hostMem.getReportRow();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||||
|
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
|
||||||
|
MemoryReportEntry entry;
|
||||||
|
entry.kind = MemoryReportEntry::Kind::Batch;
|
||||||
|
entry.id = batchId;
|
||||||
|
llvm::append_range(entry.coreIds, coreIds);
|
||||||
|
entry.row = row;
|
||||||
|
reportEntries.push_back(std::move(entry));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimAcceleratorMemory::flushReport() {
|
||||||
|
if (!fileReport.is_open())
|
||||||
|
return;
|
||||||
|
|
||||||
|
llvm::raw_os_ostream os(fileReport);
|
||||||
|
if (hostReportRow.has_value()) {
|
||||||
|
os << "Host:\n";
|
||||||
|
printMemoryReportRow(os, *hostReportRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!reportEntries.empty()) {
|
||||||
|
if (hostReportRow.has_value())
|
||||||
|
os << "\n";
|
||||||
|
|
||||||
|
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
|
||||||
|
if (lhs.kind != rhs.kind)
|
||||||
|
return lhs.kind == MemoryReportEntry::Kind::Batch;
|
||||||
|
|
||||||
|
const MemoryReportRow& lhsRow = lhs.row;
|
||||||
|
const MemoryReportRow& rhsRow = rhs.row;
|
||||||
|
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
|
||||||
|
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
|
||||||
|
if (lhsRow.numAlloca != rhsRow.numAlloca)
|
||||||
|
return lhsRow.numAlloca > rhsRow.numAlloca;
|
||||||
|
if (lhsRow.sizeGlobal != rhsRow.sizeGlobal)
|
||||||
|
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
|
||||||
|
if (lhsRow.numGlobal != rhsRow.numGlobal)
|
||||||
|
return lhsRow.numGlobal > rhsRow.numGlobal;
|
||||||
|
return lhs.id < rhs.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (size_t index = 0; index < reportEntries.size();) {
|
||||||
|
size_t runEnd = index + 1;
|
||||||
|
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
|
||||||
|
&& reportEntries[runEnd].row == reportEntries[index].row) {
|
||||||
|
++runEnd;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
|
||||||
|
os << "Batch ";
|
||||||
|
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
|
||||||
|
if (batchIndex != index)
|
||||||
|
os << ",\n ";
|
||||||
|
os << reportEntries[batchIndex].id << " (cores ";
|
||||||
|
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
|
||||||
|
os << ")";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
llvm::SmallVector<int32_t, 8> coreIds;
|
||||||
|
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||||
|
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
|
||||||
|
os << "Core ";
|
||||||
|
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
||||||
|
}
|
||||||
|
os << ":\n";
|
||||||
|
printMemoryReportRow(os, reportEntries[index].row);
|
||||||
|
if (runEnd < reportEntries.size())
|
||||||
|
os << "\n";
|
||||||
|
|
||||||
|
index = runEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
os.flush();
|
||||||
|
fileReport.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimAcceleratorMemory::clean(mlir::Operation* op) {
|
||||||
|
for (auto value : op->getResults()) {
|
||||||
|
hostMem.remove(value);
|
||||||
|
for (auto& device : deviceMem)
|
||||||
|
device.second.remove(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
json::Object PimCodeGen::createEmptyOffset() {
|
json::Object PimCodeGen::createEmptyOffset() {
|
||||||
json::Object offset;
|
json::Object offset;
|
||||||
offset["offset_select"] = 0;
|
offset["offset_select"] = 0;
|
||||||
@@ -131,6 +306,12 @@ json::Object PimCodeGen::createEmptyOffset() {
|
|||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t PimCodeGen::remapCoreId(size_t coreId) const {
|
||||||
|
auto it = emittedCoreIds.find(coreId);
|
||||||
|
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
static json::Object createRs1OnlyOffset() {
|
static json::Object createRs1OnlyOffset() {
|
||||||
json::Object offset;
|
json::Object offset;
|
||||||
offset["offset_select"] = 1;
|
offset["offset_select"] = 1;
|
||||||
@@ -190,7 +371,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
|
|||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = opName;
|
json["op"] = opName;
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["core"] = coreId;
|
json["core"] = remapCoreId(coreId);
|
||||||
json["size"] = size;
|
json["size"] = size;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
@@ -242,10 +423,62 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
|||||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||||
|
const StaticValueKnowledge& knowledge) const {
|
||||||
|
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||||
|
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
|
||||||
|
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||||
|
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||||
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||||
|
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||||
|
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
|
||||||
|
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||||
|
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
||||||
|
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
|
||||||
|
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
|
||||||
|
|
||||||
|
int64_t axis = concatOp.getAxis();
|
||||||
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
|
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
|
||||||
|
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
||||||
|
|
||||||
|
size_t outerCount = 1;
|
||||||
|
for (int64_t dim = 0; dim < axis; ++dim)
|
||||||
|
outerCount *= static_cast<size_t>(outputShape[dim]);
|
||||||
|
|
||||||
|
size_t innerCount = 1;
|
||||||
|
for (size_t dim = static_cast<size_t>(axis) + 1; dim < outputShape.size(); ++dim)
|
||||||
|
innerCount *= static_cast<size_t>(outputShape[dim]);
|
||||||
|
|
||||||
|
size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
|
||||||
|
size_t concatOffset = 0;
|
||||||
|
for (mlir::Value input : concatOp.getInputs()) {
|
||||||
|
auto inputType = cast<ShapedType>(input.getType());
|
||||||
|
assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");
|
||||||
|
|
||||||
|
size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
|
||||||
|
size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
|
||||||
|
size_t inputAddr = addressOf(input, knowledge);
|
||||||
|
|
||||||
|
for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
|
||||||
|
size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
|
||||||
|
size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
|
||||||
|
emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
|
||||||
|
}
|
||||||
|
|
||||||
|
concatOffset += inputConcatDim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
||||||
MVMTy mvmLikeOp,
|
MVMTy mvmLikeOp,
|
||||||
@@ -256,11 +489,6 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
|||||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
|
||||||
auto type = cast<ShapedType>(value.getType());
|
|
||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
|
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
|
||||||
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
|
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
|
||||||
@@ -412,6 +640,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
|||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {}
|
||||||
|
|
||||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
||||||
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
||||||
@@ -474,67 +704,59 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
return std::to_string(size) + " Bytes";
|
return std::to_string(size) + " Bytes";
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||||
SmallVector<unsigned, 8> indices;
|
SmallVector<unsigned, 8> indices;
|
||||||
auto addIndex = [&](unsigned weightIndex) {
|
auto addIndex = [&](unsigned weightIndex) {
|
||||||
if (!llvm::is_contained(indices, weightIndex))
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
indices.push_back(weightIndex);
|
indices.push_back(weightIndex);
|
||||||
};
|
};
|
||||||
|
|
||||||
coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
|
||||||
llvm::sort(indices);
|
llvm::sort(indices);
|
||||||
return indices;
|
return indices;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||||
static OnnxMlirCompilerErrorCodes
|
return getUsedWeightIndices(coreOp.getBody().front());
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
|
||||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
|
||||||
std::error_code errorCode;
|
|
||||||
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||||
SmallPtrSet<Operation*, 16> writtenGlobals;
|
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
if (hasWeightAlways(getGlobalOp))
|
|
||||||
return;
|
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
|
||||||
if (!globalOp)
|
|
||||||
return;
|
|
||||||
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
|
||||||
return;
|
|
||||||
auto initialValue = globalOp.getInitialValue();
|
|
||||||
if (!initialValue)
|
|
||||||
return;
|
|
||||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
|
||||||
if (!denseAttr)
|
|
||||||
return;
|
|
||||||
|
|
||||||
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
|
|
||||||
ArrayRef<char> rawData = denseAttr.getRawData();
|
|
||||||
char* dst = memoryBuffer.data() + memEntry.address;
|
|
||||||
|
|
||||||
if (denseAttr.isSplat()) {
|
|
||||||
size_t elementSize = rawData.size();
|
|
||||||
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
|
|
||||||
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
|
|
||||||
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
|
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
assert(rawData.size() == memEntry.size && "Data size mismatch");
|
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||||
std::memcpy(dst, rawData.data(), rawData.size());
|
SmallVector<Operation*> coreLikeOps;
|
||||||
|
for (Operation& op : funcOp.getBody().front())
|
||||||
|
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||||
|
coreLikeOps.push_back(&op);
|
||||||
|
return coreLikeOps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||||
|
func::FuncOp funcOp,
|
||||||
|
pim::PimCoreOp coreOp,
|
||||||
|
PimAcceleratorMemory& memory) {
|
||||||
|
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!targetGlobal)
|
||||||
|
return;
|
||||||
|
|
||||||
|
mlir::Value aliasedValue;
|
||||||
|
funcOp.walk([&](memref::GetGlobalOp candidate) {
|
||||||
|
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
|
||||||
|
return;
|
||||||
|
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
|
||||||
|
aliasedValue = candidate.getResult();
|
||||||
});
|
});
|
||||||
|
|
||||||
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
|
if (aliasedValue)
|
||||||
memoryFileStream.close();
|
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
|
||||||
return CompilerSuccess;
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dispatch all operations in a core region to the appropriate code generator.
|
/// Dispatch all operations in a core region to the appropriate code generator.
|
||||||
@@ -553,12 +775,16 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
|
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
|
||||||
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
||||||
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
||||||
|
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
|
||||||
|
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
|
||||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||||
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
||||||
|
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
|
||||||
|
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||||
|
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||||
|
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
|
|
||||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||||
@@ -581,9 +807,10 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
||||||
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
||||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
||||||
|
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||||
|
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
op.emitError("Unsupported codegen for this operation");
|
||||||
op.dump();
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
processedOperations++;
|
processedOperations++;
|
||||||
@@ -592,225 +819,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write crossbar weight matrices as padded binary files for a single core.
|
|
||||||
static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
|
|
||||||
pim::PimCoreOp coreOp,
|
|
||||||
StringRef coreWeightsDirPath,
|
|
||||||
json::Array& xbarsPerGroup) {
|
|
||||||
int64_t xbarSize = crossbarSize.getValue();
|
|
||||||
std::error_code errorCode;
|
|
||||||
size_t weightIndex = 0;
|
|
||||||
|
|
||||||
for (auto weight : coreOp.getWeights()) {
|
|
||||||
xbarsPerGroup.push_back(weightIndex);
|
|
||||||
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
|
||||||
if (!getGlobalOp) {
|
|
||||||
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
|
|
||||||
weightIndex++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
|
||||||
if (!globalOp) {
|
|
||||||
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
|
|
||||||
weightIndex++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto initialValue = globalOp.getInitialValue();
|
|
||||||
if (!initialValue) {
|
|
||||||
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex));
|
|
||||||
weightIndex++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
|
||||||
if (!denseAttr) {
|
|
||||||
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex));
|
|
||||||
weightIndex++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto type = denseAttr.getType();
|
|
||||||
auto shape = type.getShape();
|
|
||||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
|
||||||
int64_t numRows = shape[0];
|
|
||||||
int64_t numCols = shape[1];
|
|
||||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
|
||||||
|
|
||||||
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
|
|
||||||
|
|
||||||
auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin").str();
|
|
||||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t zero = 0;
|
|
||||||
for (int64_t row = 0; row < xbarSize; row++) {
|
|
||||||
for (int64_t col = 0; col < xbarSize; col++) {
|
|
||||||
if (row < numRows && col < numCols) {
|
|
||||||
int64_t index = row * numCols + col;
|
|
||||||
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
|
|
||||||
uint64_t word = bits.getZExtValue();
|
|
||||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
weightFileStream.close();
|
|
||||||
weightIndex++;
|
|
||||||
}
|
|
||||||
|
|
||||||
return CompilerSuccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
|
|
||||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|
||||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
|
||||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
|
||||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
|
||||||
assert(!error && "Error creating weights directory");
|
|
||||||
size_t indexFileName = 0;
|
|
||||||
|
|
||||||
int64_t xbarSize = crossbarSize.getValue();
|
|
||||||
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
|
||||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
|
||||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
|
||||||
if (index >= coreOp.getWeights().size()) {
|
|
||||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
|
||||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
|
||||||
}
|
|
||||||
mlir::Value weight = coreOp.getWeights()[index];
|
|
||||||
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
|
||||||
if (!getGlobalOp) {
|
|
||||||
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
|
||||||
assert(!getGlobalOp && "Weight is not from a memref.get_global");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
|
||||||
if (!globalOp) {
|
|
||||||
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
|
|
||||||
assert(!globalOp && "Could not find memref.global");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto initialValue = globalOp.getInitialValue();
|
|
||||||
if (!initialValue) {
|
|
||||||
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
|
|
||||||
assert(!initialValue && "memref.global has no initial value");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
|
||||||
if (!denseAttr) {
|
|
||||||
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
|
|
||||||
assert(!denseAttr && "memref.global initial value is not dense");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mapGlobalOpToFileName.contains(globalOp)) {
|
|
||||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
|
||||||
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
|
|
||||||
mapCoreWeightToFileName[coreOp].insert(weightToFile);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto type = denseAttr.getType();
|
|
||||||
auto shape = type.getShape();
|
|
||||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
|
||||||
int64_t numRows = shape[0];
|
|
||||||
int64_t numCols = shape[1];
|
|
||||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
|
||||||
|
|
||||||
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
|
|
||||||
|
|
||||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
|
||||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
|
||||||
std::error_code errorCode;
|
|
||||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
|
||||||
assert(errorCode);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t zero = 0;
|
|
||||||
for (int64_t row = 0; row < xbarSize; row++) {
|
|
||||||
for (int64_t col = 0; col < xbarSize; col++) {
|
|
||||||
if (row < numRows && col < numCols) {
|
|
||||||
int64_t index = row * numCols + col;
|
|
||||||
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
|
|
||||||
uint64_t word = bits.getZExtValue();
|
|
||||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
weightFileStream.close();
|
|
||||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
|
||||||
mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mapCoreWeightToFileName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
|
||||||
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
|
||||||
PimAcceleratorMemory& memory,
|
|
||||||
size_t coreCount,
|
|
||||||
json::Object xbarsPerArrayGroup,
|
|
||||||
StringRef outputDirPath) {
|
|
||||||
json::Object configJson;
|
|
||||||
|
|
||||||
// +1 because pimsim-nn also considers the host as a core
|
|
||||||
configJson["core_cnt"] = coreCount + 1;
|
|
||||||
|
|
||||||
// TODO: Should this be based on the floating point type used in the model?
|
|
||||||
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
|
|
||||||
|
|
||||||
// Number of ADC for MVM units
|
|
||||||
configJson["adc_count"] = 16;
|
|
||||||
// The bit precision of each ADC
|
|
||||||
configJson["cell_precision"] = 2;
|
|
||||||
|
|
||||||
// Crossbar configuration
|
|
||||||
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
|
||||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
|
||||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
|
||||||
|
|
||||||
// Memory layout of inputs and outputs
|
|
||||||
json::Array inputsAddresses;
|
|
||||||
for (BlockArgument input : funcOp.getArguments())
|
|
||||||
inputsAddresses.push_back(memory.getValueAddress(input));
|
|
||||||
configJson["inputs_addresses"] = std::move(inputsAddresses);
|
|
||||||
|
|
||||||
json::Array outputsAddresses;
|
|
||||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
|
||||||
for (mlir::Value output : returnOp.getOperands())
|
|
||||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
|
||||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
|
||||||
|
|
||||||
auto configPath = (outputDirPath + "/config.json").str();
|
|
||||||
std::error_code errorCode;
|
|
||||||
raw_fd_ostream jsonOS(configPath, errorCode);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening config file: " << errorCode.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
jsonOS << json::Value(std::move(configJson)) << '\n';
|
|
||||||
jsonOS.close();
|
|
||||||
|
|
||||||
return CompilerSuccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
|
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||||
if (!outputDirPath.empty()) {
|
if (!outputDirPath.empty()) {
|
||||||
if (auto error = sys::fs::create_directory(outputDirPath)) {
|
if (auto error = sys::fs::create_directory(outputDirPath)) {
|
||||||
@@ -826,33 +834,49 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
|
|
||||||
PimAcceleratorMemory memory;
|
PimAcceleratorMemory memory;
|
||||||
memory.hostMem.allocateHost(moduleOp, funcOp);
|
memory.hostMem.allocateHost(moduleOp, funcOp);
|
||||||
|
memory.reportHost();
|
||||||
|
|
||||||
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
||||||
return err;
|
return err;
|
||||||
|
|
||||||
// Write empty host core file
|
if (auto err = writeHostCoreJson(outputDirPath))
|
||||||
std::error_code errorCode;
|
return err;
|
||||||
auto outputHostCorePath = outputDirPath + "/core_0.json";
|
|
||||||
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
// The host core json contains 2 random instructions, just to make pimsim-nn happy
|
|
||||||
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
|
||||||
hostFileStream.close();
|
|
||||||
|
|
||||||
// For each core, specify the number of crossbar per array group.
|
// For each core, specify the number of crossbar per array group.
|
||||||
// This implementation always assigns one crossbar per group.
|
// This implementation always assigns one crossbar per group.
|
||||||
json::Object xbarsPerArrayGroup;
|
json::Object xbarsPerArrayGroup;
|
||||||
size_t coreCount = 0;
|
size_t maxCoreId = 0;
|
||||||
|
uint64_t nextBatchReportId = 0;
|
||||||
|
|
||||||
// Create Weight Folder
|
// Create Weight Folder
|
||||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||||
|
|
||||||
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||||
auto coreId = coreOp.getCoreId();
|
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
||||||
coreCount++;
|
size_t nextEmittedCoreId = 1;
|
||||||
|
|
||||||
|
for (Operation* op : coreLikeOps) {
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
||||||
|
if (!emittedCoreIds.contains(originalCoreId))
|
||||||
|
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||||
|
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||||
|
if (!emittedCoreIds.contains(originalCoreId))
|
||||||
|
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation* op : coreLikeOps) {
|
||||||
|
auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
||||||
|
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||||
|
maxCoreId = std::max(maxCoreId, coreId);
|
||||||
|
|
||||||
std::error_code errorCode;
|
std::error_code errorCode;
|
||||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||||
@@ -863,7 +887,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
}
|
}
|
||||||
coreFileStream << '[';
|
coreFileStream << '[';
|
||||||
|
|
||||||
PimCodeGen coreCodeGen(memory, coreFileStream);
|
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||||
|
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||||
|
|
||||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||||
@@ -871,19 +896,17 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
return CompilerFailure;
|
return CompilerFailure;
|
||||||
assert(processedOperations > 0);
|
assert(processedOperations > 0);
|
||||||
|
|
||||||
// Remove trailing comma, close JSON array
|
|
||||||
coreFileStream.seek(coreFileStream.tell() - 1);
|
coreFileStream.seek(coreFileStream.tell() - 1);
|
||||||
coreFileStream << ']';
|
coreFileStream << ']';
|
||||||
coreFileStream.close();
|
coreFileStream.close();
|
||||||
|
|
||||||
// Write crossbar weights for this core
|
|
||||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
||||||
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
|
||||||
json::Array xbarsPerGroup;
|
json::Array xbarsPerGroup;
|
||||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
if (index >= coreOp.getWeights().size()) {
|
if (index >= coreOp.getWeights().size()) {
|
||||||
@@ -897,14 +920,48 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
||||||
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
|
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
|
||||||
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
||||||
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
|
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")
|
||||||
<< '\n';
|
<< "\nError:" << error.message() << '\n';
|
||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
||||||
|
if (temporaryCore)
|
||||||
|
coreOp.walk([&memory](Operation* op) { memory.clean(op); });
|
||||||
|
return CompilerSuccess;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
|
if (auto err = emitCore(coreOp, false))
|
||||||
|
return err;
|
||||||
|
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||||
|
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||||
|
.getReportRow());
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath);
|
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||||
|
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
SmallVector<int32_t> reportedCoreIds;
|
||||||
|
reportedCoreIds.reserve(batchCoreIds.size());
|
||||||
|
MemoryReportRow batchRow;
|
||||||
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||||
|
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||||
|
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||||
|
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||||
|
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||||
|
laneResult = emitCore(coreOp, true);
|
||||||
|
if (laneResult == CompilerSuccess)
|
||||||
|
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
|
||||||
|
return laneResult == CompilerSuccess ? success() : failure();
|
||||||
|
})))
|
||||||
|
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
||||||
|
}
|
||||||
|
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
|
||||||
|
}
|
||||||
|
|
||||||
|
memory.flushReport();
|
||||||
|
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
@@ -14,12 +21,34 @@ struct MemEntry {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MemoryReportRow {
|
||||||
|
uint64_t numAlloca = 0;
|
||||||
|
uint64_t sizeAlloca = 0;
|
||||||
|
uint64_t numGlobal = 0;
|
||||||
|
uint64_t sizeGlobal = 0;
|
||||||
|
|
||||||
|
bool operator==(const MemoryReportRow& other) const {
|
||||||
|
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
|
||||||
|
&& sizeGlobal == other.sizeGlobal;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MemoryReportEntry {
|
||||||
|
enum class Kind {
|
||||||
|
Core,
|
||||||
|
Batch
|
||||||
|
};
|
||||||
|
|
||||||
|
Kind kind = Kind::Core;
|
||||||
|
uint64_t id = 0;
|
||||||
|
llvm::SmallVector<int32_t, 8> coreIds;
|
||||||
|
MemoryReportRow row;
|
||||||
|
};
|
||||||
|
|
||||||
class PimMemory {
|
class PimMemory {
|
||||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||||
|
|
||||||
size_t maxSize = 0; // 0 for unbounded memory
|
|
||||||
size_t startAddress = 0;
|
|
||||||
size_t minAlignment = 4;
|
size_t minAlignment = 4;
|
||||||
size_t firstAvailableAddress = 0;
|
size_t firstAvailableAddress = 0;
|
||||||
|
|
||||||
@@ -33,6 +62,8 @@ public:
|
|||||||
|
|
||||||
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||||
void allocateCore(mlir::Operation* op);
|
void allocateCore(mlir::Operation* op);
|
||||||
|
MemoryReportRow getReportRow() const;
|
||||||
|
void remove(mlir::Value val);
|
||||||
|
|
||||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||||
MemEntry getMemEntry(mlir::Value value) const;
|
MemEntry getMemEntry(mlir::Value value) const;
|
||||||
@@ -45,23 +76,43 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||||
|
std::fstream fileReport;
|
||||||
|
std::optional<MemoryReportRow> hostReportRow;
|
||||||
|
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimAcceleratorMemory()
|
PimAcceleratorMemory()
|
||||||
: hostMem(memEntriesMap) {}
|
: hostMem(memEntriesMap) {
|
||||||
|
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string dialectsDir = outputDir + "/reports/";
|
||||||
|
createDirectory(dialectsDir);
|
||||||
|
std::fstream file(dialectsDir + "/memory_report.txt", std::ios::out);
|
||||||
|
fileReport = std::move(file);
|
||||||
|
}
|
||||||
|
|
||||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||||
|
|
||||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||||
|
void reportHost();
|
||||||
|
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||||
|
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
|
||||||
|
void flushReport();
|
||||||
|
void clean(mlir::Operation* op);
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimCodeGen {
|
class PimCodeGen {
|
||||||
PimAcceleratorMemory& memory;
|
PimAcceleratorMemory& memory;
|
||||||
llvm::raw_fd_ostream& coreFileStream;
|
llvm::raw_fd_ostream& coreFileStream;
|
||||||
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||||
|
|
||||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||||
return memory.getValueAddress(value, knowledge);
|
return memory.getValueAddress(value, knowledge);
|
||||||
}
|
}
|
||||||
|
size_t remapCoreId(size_t coreId) const;
|
||||||
|
|
||||||
static llvm::json::Object createEmptyOffset();
|
static llvm::json::Object createEmptyOffset();
|
||||||
void emitInstruction(llvm::json::Object instruction) const;
|
void emitInstruction(llvm::json::Object instruction) const;
|
||||||
@@ -83,15 +134,20 @@ class PimCodeGen {
|
|||||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
PimCodeGen(PimAcceleratorMemory& memory,
|
||||||
: memory(memory), coreFileStream(coreJson) {}
|
llvm::raw_fd_ostream& coreJson,
|
||||||
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
||||||
|
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
|
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
|
||||||
@@ -106,6 +162,7 @@ public:
|
|||||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,3 @@
|
|||||||
/*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
*/
|
|
||||||
|
|
||||||
//===------------------------- PimCompilerOptions.cpp --------------------===//
|
|
||||||
//
|
|
||||||
// Copyright 2022 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
// Compiler Options for PIM
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerOptions"
|
#define DEBUG_TYPE "PimCompilerOptions"
|
||||||
@@ -41,7 +28,7 @@ llvm::cl::opt<size_t>
|
|||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
|
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||||
|
|
||||||
llvm::cl::opt<long> coresCount("core-count",
|
llvm::cl::opt<long> coresCount("core-count",
|
||||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||||
@@ -51,7 +38,7 @@ llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
|||||||
"dcp-critical-window-size",
|
"dcp-critical-window-size",
|
||||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||||
llvm::cl::init(1024));
|
llvm::cl::init(4000));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
ignoreConcatError("ignore-concat-error",
|
ignoreConcatError("ignore-concat-error",
|
||||||
|
|||||||
@@ -0,0 +1,220 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FileSystem.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct DenseWeightView {
|
||||||
|
DenseElementsAttr denseAttr;
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
int64_t offset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
||||||
|
strides[index] = strides[index + 1] * shape[index + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool allStaticSubviewParts(memref::SubViewOp subview) {
|
||||||
|
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||||
|
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||||
|
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||||
|
SmallVector<memref::SubViewOp> subviews;
|
||||||
|
mlir::Value current = weight;
|
||||||
|
memref::GetGlobalOp getGlobalOp;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
Operation* defOp = current.getDefiningOp();
|
||||||
|
if (!defOp)
|
||||||
|
return failure();
|
||||||
|
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||||
|
break;
|
||||||
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||||
|
if (!allStaticSubviewParts(subview))
|
||||||
|
return failure();
|
||||||
|
subviews.push_back(subview);
|
||||||
|
current = subview.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||||
|
current = cast.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getInitialValue())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
DenseWeightView view;
|
||||||
|
view.denseAttr = denseAttr;
|
||||||
|
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||||
|
view.strides = computeRowMajorStridesForShape(view.shape);
|
||||||
|
|
||||||
|
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||||
|
SmallVector<int64_t> nextStrides;
|
||||||
|
nextStrides.reserve(subview.getStaticStrides().size());
|
||||||
|
for (auto [offset, stride, sourceStride] :
|
||||||
|
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||||
|
view.offset += offset * sourceStride;
|
||||||
|
nextStrides.push_back(stride * sourceStride);
|
||||||
|
}
|
||||||
|
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||||
|
view.strides = std::move(nextStrides);
|
||||||
|
}
|
||||||
|
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||||
|
SmallVector<unsigned, 8> indices;
|
||||||
|
auto addIndex = [&](unsigned weightIndex) {
|
||||||
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
|
indices.push_back(weightIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
|
llvm::sort(indices);
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||||
|
return getUsedWeightIndices(coreOp.getBody().front());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||||
|
SmallVector<Operation*> coreLikeOps;
|
||||||
|
for (Operation& op : funcOp.getBody().front())
|
||||||
|
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||||
|
coreLikeOps.push_back(&op);
|
||||||
|
return coreLikeOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||||
|
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||||
|
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||||
|
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||||
|
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||||
|
assert(!error && "Error creating weights directory");
|
||||||
|
size_t indexFileName = 0;
|
||||||
|
|
||||||
|
int64_t xbarSize = crossbarSize.getValue();
|
||||||
|
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||||
|
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||||
|
|
||||||
|
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||||
|
|
||||||
|
for (Operation* op : coreLikeOps) {
|
||||||
|
auto processCore = [&](pim::PimCoreOp coreOp) {
|
||||||
|
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
|
||||||
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
|
|
||||||
|
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||||
|
if (failed(weightView)) {
|
||||||
|
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||||
|
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||||
|
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||||
|
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||||
|
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||||
|
ArrayRef<int64_t> shape = weightView->shape;
|
||||||
|
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||||
|
int64_t numRows = shape[0];
|
||||||
|
int64_t numCols = shape[1];
|
||||||
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||||
|
|
||||||
|
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
|
|
||||||
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||||
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||||
|
assert(errorCode);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t zero = 0;
|
||||||
|
for (int64_t row = 0; row < xbarSize; row++) {
|
||||||
|
for (int64_t col = 0; col < xbarSize; col++) {
|
||||||
|
if (row < numRows && col < numCols) {
|
||||||
|
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||||
|
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||||
|
uint64_t word = bits.getZExtValue();
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
weightFileStream.close();
|
||||||
|
if (globalOp)
|
||||||
|
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||||
|
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
|
(void) processCore(coreOp);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||||
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||||
|
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
|
||||||
|
return mapCoreWeightToFileName;
|
||||||
|
}
|
||||||
|
return mapCoreWeightToFileName;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||||
|
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|||||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||||
|
|
||||||
add_pim_library(OMONNXToSpatial
|
add_pim_library(OMONNXToSpatial
|
||||||
|
ConversionPatterns.cpp
|
||||||
|
HostFoldability.cpp
|
||||||
|
HostLegality.cpp
|
||||||
|
PrePatterns.cpp
|
||||||
|
PostPatterns.cpp
|
||||||
Patterns/Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
Patterns/Math/Elementwise.cpp
|
Patterns/Math/Elementwise.cpp
|
||||||
Patterns/Math/Gemm.cpp
|
Patterns/Math/Gemm.cpp
|
||||||
@@ -18,7 +23,9 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Reshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
Common.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
|
Common/ShapeTilingUtils.cpp
|
||||||
|
Common/WeightMaterialization.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -1,279 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageN(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
using HSliceId = size_t;
|
|
||||||
using CoreId = size_t;
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr C ceilIntegerDivide(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return 1 + (ac - 1) / bc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[0] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[1] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
|
||||||
assert(isVectorShape(shape));
|
|
||||||
return shape[0] != 1 ? shape[0] : shape[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
inline auto getTensorShape(mlir::Value tensor) {
|
|
||||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool isWeightLikeComputeOperand(mlir::Value value) {
|
|
||||||
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
|
|
||||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
|
|
||||||
|
|
||||||
while (auto* definingOp = value.getDefiningOp()) {
|
|
||||||
if (!visited.insert(definingOp).second)
|
|
||||||
return false;
|
|
||||||
if (hasWeightAlways(definingOp))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
|
|
||||||
value = extractSliceOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = expandShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = collapseShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
|
|
||||||
value = transposeOp.getData();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
|
||||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
|
||||||
return std::forward<Fn>(fn)(values[Is]...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t>
|
|
||||||
using ValueArg = mlir::Value;
|
|
||||||
|
|
||||||
template <typename Fn, typename Seq>
|
|
||||||
struct InvokeWithBlockArgsResult;
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
|
||||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Fn, typename Seq>
|
|
||||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
|
||||||
|
|
||||||
template <typename Fn>
|
|
||||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::TypeRange resultTypes,
|
|
||||||
mlir::ValueRange weights,
|
|
||||||
mlir::ValueRange inputs,
|
|
||||||
BodyFn&& body) {
|
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
|
||||||
for (mlir::Value input : inputs)
|
|
||||||
block->addArgument(input.getType(), loc);
|
|
||||||
|
|
||||||
computeOp.getBody().push_back(block);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
|
||||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
|
||||||
auto bodyResult =
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
|
||||||
if (mlir::failed(bodyResult)) {
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
rewriter.eraseOp(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return computeOp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename RewriterT, typename BodyFn>
|
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::TypeRange resultTypes,
|
|
||||||
mlir::ValueRange weights,
|
|
||||||
mlir::ValueRange inputs,
|
|
||||||
BodyFn&& body) {
|
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
|
||||||
for (mlir::Value input : inputs)
|
|
||||||
block->addArgument(input.getType(), loc);
|
|
||||||
|
|
||||||
computeOp.getBody().push_back(block);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
|
||||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
|
||||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
|
||||||
if (mlir::failed(bodyResult)) {
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
rewriter.eraseOp(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
|
||||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return computeOp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
|
||||||
size_t axis,
|
|
||||||
int64_t sliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
|
||||||
int64_t sliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
|
||||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
|
||||||
tileMatrix(mlir::Value& matrixToTile,
|
|
||||||
int64_t hSliceSize,
|
|
||||||
int64_t vSliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location& loc);
|
|
||||||
|
|
||||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
|
||||||
int64_t length,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "WeightMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||||
|
if (tensors.size() == 1)
|
||||||
|
return tensors[0];
|
||||||
|
|
||||||
|
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||||
|
SmallVector<Value> tensors2;
|
||||||
|
tensors2.reserve(tensors.size() / 2);
|
||||||
|
|
||||||
|
auto* currTensors = &tensors1;
|
||||||
|
auto* nextTensors = &tensors2;
|
||||||
|
while (currTensors->size() > 1) {
|
||||||
|
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||||
|
Value a = (*currTensors)[i];
|
||||||
|
Value b = (*currTensors)[i + 1];
|
||||||
|
rewriter.setInsertionPointAfterValue(b);
|
||||||
|
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||||
|
nextTensors->push_back(addedValue);
|
||||||
|
}
|
||||||
|
if (currTensors->size() % 2 == 1)
|
||||||
|
nextTensors->push_back(currTensors->back());
|
||||||
|
std::swap(currTensors, nextTensors);
|
||||||
|
nextTensors->clear();
|
||||||
|
}
|
||||||
|
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||||
|
return (*currTensors)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
|
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||||
|
return std::forward<Fn>(fn)(values[Is]...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t>
|
||||||
|
using ValueArg = mlir::Value;
|
||||||
|
|
||||||
|
template <typename Fn, typename Seq>
|
||||||
|
struct InvokeWithBlockArgsResult;
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||||
|
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Fn, typename Seq>
|
||||||
|
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename RewriterT>
|
||||||
|
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
||||||
|
assert(!inputs.empty() && "spat.concat requires at least one input");
|
||||||
|
if (inputs.size() == 1)
|
||||||
|
return inputs.front();
|
||||||
|
|
||||||
|
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
||||||
|
auto outputShape = llvm::to_vector(firstType.getShape());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
bool concatDimDynamic = false;
|
||||||
|
|
||||||
|
for (mlir::Value input : inputs) {
|
||||||
|
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
||||||
|
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
||||||
|
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
||||||
|
concatDimDynamic = true;
|
||||||
|
else
|
||||||
|
concatDimSize += inputType.getDimSize(axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
||||||
|
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
||||||
|
/// the body callback reports failure.
|
||||||
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult =
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
||||||
|
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+26
-46
@@ -1,24 +1,12 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "Common.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -44,10 +32,29 @@ SmallVector<Value> sliceTensor(
|
|||||||
|
|
||||||
for (int64_t i = 0; i < numSlices; i++) {
|
for (int64_t i = 0; i < numSlices; i++) {
|
||||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
int64_t currentSliceSize = sliceSize;
|
||||||
|
if (i == numSlices - 1 && lastSliceSize != 0) {
|
||||||
|
currentSliceSize = lastSliceSize;
|
||||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||||
|
}
|
||||||
|
|
||||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
|
||||||
|
sliceShape[axis] = currentSliceSize;
|
||||||
|
auto sliceType =
|
||||||
|
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||||
|
|
||||||
|
Value slice;
|
||||||
|
if (isHostFoldableValue(tensorToSlice)) {
|
||||||
|
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto sliceCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
|
||||||
|
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
|
||||||
|
});
|
||||||
|
slice = sliceCompute.getResult(0);
|
||||||
|
}
|
||||||
slices.push_back(slice);
|
slices.push_back(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,31 +114,4 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
|
|||||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
} // namespace onnx_mlir
|
||||||
if (tensors.size() == 1)
|
|
||||||
return tensors[0];
|
|
||||||
|
|
||||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
|
||||||
SmallVector<Value> tensors2;
|
|
||||||
tensors2.reserve(tensors.size() / 2);
|
|
||||||
|
|
||||||
auto* currTensors = &tensors1;
|
|
||||||
auto* nextTensors = &tensors2;
|
|
||||||
while (currTensors->size() > 1) {
|
|
||||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
|
||||||
Value a = (*currTensors)[i];
|
|
||||||
Value b = (*currTensors)[i + 1];
|
|
||||||
rewriter.setInsertionPointAfterValue(b);
|
|
||||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
|
||||||
nextTensors->push_back(addedValue);
|
|
||||||
}
|
|
||||||
if (currTensors->size() % 2 == 1)
|
|
||||||
nextTensors->push_back(currTensors->back());
|
|
||||||
std::swap(currTensors, nextTensors);
|
|
||||||
nextTensors->clear();
|
|
||||||
}
|
|
||||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
|
||||||
return (*currTensors)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageN(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
using HSliceId = size_t;
|
||||||
|
using CoreId = size_t;
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr C ceilIntegerDivide(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return 1 + (ac - 1) / bc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[0] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[1] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||||
|
assert(isVectorShape(shape));
|
||||||
|
return shape[0] != 1 ? shape[0] : shape[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
inline auto getTensorShape(mlir::Value tensor) {
|
||||||
|
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||||
|
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||||
|
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||||
|
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||||
|
&& lhsType.getShape() == rhsType.getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||||
|
/// at most `sliceSize` elements.
|
||||||
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
|
size_t axis,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||||
|
/// current PIM target geometry.
|
||||||
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
|
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
|
/// Tiles a matrix first across output columns and then across input rows so it
|
||||||
|
/// can be assigned to crossbars grouped by core.
|
||||||
|
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||||
|
tileMatrix(mlir::Value& matrixToTile,
|
||||||
|
int64_t hSliceSize,
|
||||||
|
int64_t vSliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location& loc);
|
||||||
|
|
||||||
|
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||||
|
int64_t length,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "WeightMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isWeightLikeComputeOperand(Value value) {
|
||||||
|
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
|
||||||
|
while (auto* definingOp = value.getDefiningOp()) {
|
||||||
|
if (!visited.insert(definingOp).second)
|
||||||
|
return false;
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
value = extractSliceOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||||
|
value = transposeOp.getData();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(value))
|
||||||
|
return cast<Value>(mapped);
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||||
|
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!tensorType || !tensorType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
sizes.reserve(tensorType.getRank());
|
||||||
|
for (int64_t dim : tensorType.getShape())
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
|
|
||||||
|
auto referencedValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||||
|
mapper.map(value, referencedValue.getResult());
|
||||||
|
return referencedValue.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
IRMapping localMapper;
|
||||||
|
for (Value operand : definingOp->getOperands()) {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||||
|
localMapper.map(operand, cast<Value>(mapped));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isWeightLikeComputeOperand(operand)) {
|
||||||
|
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||||
|
if (failed(clonedOperand))
|
||||||
|
return failure();
|
||||||
|
localMapper.map(operand, *clonedOperand);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
localMapper.map(operand, operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||||
|
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapper.map(oldResult, newResult);
|
||||||
|
|
||||||
|
auto mapped = mapper.lookupOrNull(value);
|
||||||
|
if (!mapped)
|
||||||
|
return failure();
|
||||||
|
return cast<Value>(mapped);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns true when a matrix-valued compute operand is ultimately backed by a
|
||||||
|
/// weight-marked constant/view chain and can be promoted into weights.
|
||||||
|
bool isWeightLikeComputeOperand(mlir::Value value);
|
||||||
|
|
||||||
|
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
|
||||||
|
/// compute body while reusing already-materialized intermediate values.
|
||||||
|
llvm::FailureOr<mlir::Value>
|
||||||
|
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||||
|
patterns.add<removeLRN>(ctx);
|
||||||
|
|
||||||
|
populateElementwisePatterns(patterns, ctx);
|
||||||
|
populateGemmPatterns(patterns, ctx);
|
||||||
|
populateConvPatterns(patterns, ctx);
|
||||||
|
populatePoolPatterns(patterns, ctx);
|
||||||
|
populateReduceMeanPatterns(patterns, ctx);
|
||||||
|
populateReluPatterns(patterns, ctx);
|
||||||
|
populateSigmoidPatterns(patterns, ctx);
|
||||||
|
populateSoftmaxPatterns(patterns, ctx);
|
||||||
|
populateConcatPatterns(patterns, ctx);
|
||||||
|
populateGatherPatterns(patterns, ctx);
|
||||||
|
populateResizePatterns(patterns, ctx);
|
||||||
|
populateReshapePatterns(patterns, ctx);
|
||||||
|
populateSplitPatterns(patterns, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+2
@@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||||
|
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isStaticTensorResult(Operation* op) {
|
||||||
|
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||||
|
auto shapedType = dyn_cast<ShapedType>(type);
|
||||||
|
return shapedType && shapedType.hasStaticShape();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||||
|
if (!op || !visited.insert(op).second)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (!isStaticTensorResult(op))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||||
|
return isHostFoldableValue(transposeOp.getData());
|
||||||
|
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||||
|
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||||
|
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||||
|
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
|
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||||
|
|
||||||
|
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||||
|
return isHostFoldableValue(extractRowsOp.getInput());
|
||||||
|
|
||||||
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||||
|
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isHostFoldableValue(Value value) {
|
||||||
|
auto* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
return isHostFoldableOpImpl(definingOp, visited);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isHostFoldableOp(Operation* op) {
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
return isHostFoldableOpImpl(op, visited);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isHostFoldableValue(mlir::Value value);
|
||||||
|
|
||||||
|
bool isHostFoldableOp(mlir::Operation* op);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
|
||||||
|
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||||
|
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||||
|
continue;
|
||||||
|
if (isHostFoldableOp(&op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
@@ -7,25 +8,18 @@
|
|||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <fstream>
|
#include "Common/Common.hpp"
|
||||||
#include <iterator>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "Common.hpp"
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -33,12 +27,8 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
bool haveSameStaticShape(Value lhs, Value rhs);
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
|
||||||
|
|
||||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||||
@@ -48,33 +38,64 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||||
|
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
|
|
||||||
private:
|
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
|
||||||
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
|
||||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
IRMapping mapper;
|
||||||
|
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||||
|
if (!computes.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
|
||||||
|
SmallVector<Type> sourceTypes;
|
||||||
|
SmallVector<Location> sourceLocs;
|
||||||
|
sourceTypes.reserve(funcOp.getNumArguments());
|
||||||
|
sourceLocs.reserve(funcOp.getNumArguments());
|
||||||
|
for (Value source : funcOp.getArguments()) {
|
||||||
|
sourceTypes.push_back(source.getType());
|
||||||
|
sourceLocs.push_back(source.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newCompute = spatial::SpatCompute::create(
|
||||||
|
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||||
|
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||||
|
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||||
|
mapper.map(computeArg, blockArg);
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
for (Operation& op : funcOp.getOps())
|
||||||
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||||
|
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||||
|
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||||
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
||||||
|
op.dropAllUses();
|
||||||
|
rewriter.eraseOp(&op);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||||
|
returnOp.setOperand(index, computeResult);
|
||||||
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
RewritePatternSet mergeActivationPatterns(ctx);
|
RewritePatternSet prePatterns(ctx);
|
||||||
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
populatePrePatterns(prePatterns, ctx);
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
|
||||||
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
|
||||||
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
|
||||||
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
|
||||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
|
||||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
|
||||||
|
|
||||||
IRRewriter rewriter(moduleOp);
|
|
||||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
if (failed(entryFunc)) {
|
if (failed(entryFunc)) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
@@ -87,8 +108,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
target.addIllegalOp<ONNXMatMulOp>();
|
||||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
|
||||||
target.addIllegalOp<ONNXAddOp>();
|
target.addIllegalOp<ONNXAddOp>();
|
||||||
target.addIllegalOp<ONNXDivOp>();
|
target.addIllegalOp<ONNXDivOp>();
|
||||||
target.addIllegalOp<ONNXMulOp>();
|
target.addIllegalOp<ONNXMulOp>();
|
||||||
@@ -107,32 +127,23 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||||
target.addIllegalOp<ONNXSplitOp>();
|
target.addIllegalOp<ONNXSplitOp>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet conversionPatterns(ctx);
|
||||||
patterns.add<removeLRN>(ctx);
|
populateConversionPatterns(conversionPatterns, ctx);
|
||||||
|
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||||
populateElementwisePatterns(patterns, ctx);
|
signalPassFailure();
|
||||||
populateGemmPatterns(patterns, ctx);
|
return;
|
||||||
populateConvPatterns(patterns, ctx);
|
}
|
||||||
populatePoolPatterns(patterns, ctx);
|
|
||||||
populateReduceMeanPatterns(patterns, ctx);
|
RewritePatternSet earlyPostPatterns(ctx);
|
||||||
populateReluPatterns(patterns, ctx);
|
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||||
populateSigmoidPatterns(patterns, ctx);
|
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||||
populateSoftmaxPatterns(patterns, ctx);
|
|
||||||
populateConcatPatterns(patterns, ctx);
|
|
||||||
populateGatherPatterns(patterns, ctx);
|
|
||||||
populateResizePatterns(patterns, ctx);
|
|
||||||
populateReshapePatterns(patterns, ctx);
|
|
||||||
populateSplitPatterns(patterns, ctx);
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count the number of compute ops and check they do not exceed the core count
|
|
||||||
if (coresCount != -1) {
|
if (coresCount != -1) {
|
||||||
int computeOpsCount = 0;
|
int computeOpsCount = 0;
|
||||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||||
if (isa<spatial::SpatCompute>(op))
|
if (isa<spatial::SpatCompute>(op))
|
||||||
computeOpsCount++;
|
computeOpsCount++;
|
||||||
|
|
||||||
@@ -149,334 +160,24 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
|
||||||
|
|
||||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
RewritePatternSet postPatterns(ctx);
|
||||||
|
populatePostPatterns(postPatterns, ctx);
|
||||||
|
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
mergeTriviallyConnectedComputes(*entryFunc);
|
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
// Dump to file for debug
|
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
|
||||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
|
||||||
Value source = funcSource(toRemoveOp);
|
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
|
||||||
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
|
||||||
IRMapping mapper;
|
|
||||||
mapper.map(source, BB->getArgument(0));
|
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
|
||||||
inst->erase();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|
||||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
|
||||||
auto sources = toRemoveOp.getInputs();
|
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
|
||||||
if (llvm::any_of(
|
|
||||||
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
|
||||||
SmallVector<Type> sourceTypes;
|
|
||||||
SmallVector<Location> sourceLoc;
|
|
||||||
for (auto source : sources) {
|
|
||||||
sourceTypes.push_back(source.getType());
|
|
||||||
sourceLoc.push_back(loc);
|
|
||||||
}
|
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
|
||||||
IRMapping mapper;
|
|
||||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
|
||||||
mapper.map(source, bbArg);
|
|
||||||
auto newConcat = rewriter.clone(*inst, mapper);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
|
||||||
inst->erase();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
|
||||||
if (auto mapped = mapper.lookupOrNull(value))
|
|
||||||
return cast<Value>(mapped);
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
|
||||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
|
||||||
if (!tensorType || !tensorType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
|
||||||
sizes.reserve(tensorType.getRank());
|
|
||||||
for (int64_t dim : tensorType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
|
||||||
|
|
||||||
auto referencedValue =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
|
||||||
mapper.map(value, referencedValue.getResult());
|
|
||||||
return referencedValue.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
IRMapping localMapper;
|
|
||||||
for (Value operand : definingOp->getOperands()) {
|
|
||||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
|
||||||
localMapper.map(operand, cast<Value>(mapped));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isWeightLikeComputeOperand(operand)) {
|
|
||||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
|
||||||
if (failed(clonedOperand))
|
|
||||||
return failure();
|
|
||||||
localMapper.map(operand, *clonedOperand);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
localMapper.map(operand, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
|
||||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
|
||||||
mapper.map(oldResult, newResult);
|
|
||||||
|
|
||||||
auto mapped = mapper.lookupOrNull(value);
|
|
||||||
if (!mapped)
|
|
||||||
return failure();
|
|
||||||
return cast<Value>(mapped);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO what we want to keep in global?
|
|
||||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|
||||||
Location loc = funcOp.getLoc();
|
|
||||||
IRRewriter rewriter(&getContext());
|
|
||||||
bool keep = true;
|
|
||||||
while (keep) {
|
|
||||||
keep = false;
|
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
|
||||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<ONNXTransposeOp>(
|
|
||||||
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<tensor::CollapseShapeOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
|
||||||
|
|
||||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|
||||||
Location loc = funcOp.getLoc();
|
|
||||||
IRRewriter rewriter(&getContext());
|
|
||||||
SmallVector<spatial::SpatCompute> trivialComputes;
|
|
||||||
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
|
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
|
||||||
if (compute->hasOneUse()) {
|
|
||||||
auto& use = *compute->getUses().begin();
|
|
||||||
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
|
||||||
|
|
||||||
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
|
||||||
trivialComputes.push_back(compute);
|
|
||||||
}
|
|
||||||
|
|
||||||
while (!trivialComputes.empty()) {
|
|
||||||
auto compute = trivialComputes.front();
|
|
||||||
|
|
||||||
if (compute.use_empty()) {
|
|
||||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
|
||||||
trivialComputes.pop_back();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto& computeUse = *compute->getUses().begin();
|
|
||||||
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
|
|
||||||
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
|
||||||
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
|
||||||
|
|
||||||
auto newCompute =
|
|
||||||
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
auto weightMutableIter = newCompute.getWeightsMutable();
|
|
||||||
for (auto weight : child.getWeights()) {
|
|
||||||
auto founded = llvm::find(newCompute.getWeights(), weight);
|
|
||||||
if (founded == newCompute.getWeights().end()) {
|
|
||||||
weightMutableIter.append(weight);
|
|
||||||
auto last = weightMutableIter.end();
|
|
||||||
last = std::prev(last, 1);
|
|
||||||
mapper.map(weight, last->get());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
mapper.map(weight, *founded);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
|
||||||
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
|
||||||
newTerminator->erase();
|
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
|
||||||
for (auto& op : child.getBody().front()) {
|
|
||||||
auto newInst = rewriter.clone(op, mapper);
|
|
||||||
|
|
||||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
|
||||||
auto oldIndex = vmOp.getWeightIndex();
|
|
||||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
|
||||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
|
||||||
vmOp.setWeightIndex(newIndex);
|
|
||||||
}
|
|
||||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
|
||||||
auto oldIndex = vmOp.getWeightIndex();
|
|
||||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
|
||||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
|
||||||
vmOp.setWeightIndex(newIndex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
child.replaceAllUsesWith(newCompute);
|
|
||||||
toErase.insert(child);
|
|
||||||
|
|
||||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
|
||||||
trivialComputes.pop_back();
|
|
||||||
toErase.insert(compute);
|
|
||||||
|
|
||||||
if (newCompute->hasOneUse()) {
|
|
||||||
auto& use = *newCompute->getUses().begin();
|
|
||||||
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
|
||||||
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
|
||||||
trivialComputes.push_back(newCompute);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto compute : toErase) {
|
|
||||||
for (Value result : compute->getResults())
|
|
||||||
result.dropAllUses();
|
|
||||||
compute.erase();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
|
||||||
bool isAlwaysWeight =
|
|
||||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
|
||||||
if (isAlwaysWeight)
|
|
||||||
markWeightAlways(constantOp);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
|
||||||
IRRewriter rewriter(&getContext());
|
|
||||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
|
||||||
|
|
||||||
for (auto compute : computes) {
|
|
||||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
|
||||||
bool needsRewrite = false;
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
|
||||||
continue;
|
|
||||||
promoteInput[inputIdx] = true;
|
|
||||||
needsRewrite = true;
|
|
||||||
}
|
|
||||||
if (!needsRewrite)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
|
||||||
|
|
||||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
|
||||||
SmallVector<Value> newInputs;
|
|
||||||
SmallVector<Type> newInputTypes;
|
|
||||||
SmallVector<Location> newInputLocs;
|
|
||||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
|
||||||
newInputs.reserve(compute.getInputs().size());
|
|
||||||
newInputTypes.reserve(compute.getInputs().size());
|
|
||||||
newInputLocs.reserve(compute.getInputs().size());
|
|
||||||
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (promoteInput[inputIdx]) {
|
|
||||||
newWeights.push_back(input);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
newInputs.push_back(input);
|
|
||||||
newInputTypes.push_back(input.getType());
|
|
||||||
newInputLocs.push_back(input.getLoc());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newCompute =
|
|
||||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
|
||||||
auto* newBlock =
|
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
auto& oldBlock = compute.getBody().front();
|
|
||||||
size_t newInputIdx = 0;
|
|
||||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
|
||||||
if (!promoteInput[oldInputIdx]) {
|
|
||||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
|
|
||||||
if (failed(clonedValue))
|
|
||||||
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
|
|
||||||
mapper.map(oldArg, *clonedValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& op : oldBlock.without_terminator())
|
|
||||||
rewriter.clone(op, mapper);
|
|
||||||
|
|
||||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
|
||||||
SmallVector<Value> newYieldOperands;
|
|
||||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
|
||||||
for (Value operand : oldYield.getOutputs()) {
|
|
||||||
auto mapped = mapper.lookupOrNull(operand);
|
|
||||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
|
||||||
}
|
|
||||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
|
||||||
|
|
||||||
compute.replaceAllUsesWith(newCompute);
|
|
||||||
compute.erase();
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -7,11 +7,10 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -147,11 +146,11 @@ static Value buildPackedBias(bool hasBias,
|
|||||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<Value> createIm2colRowComputes(Value x,
|
static Value createIm2colRowComputes(Value x,
|
||||||
RankedTensorType xType,
|
RankedTensorType xType,
|
||||||
RankedTensorType im2colType,
|
RankedTensorType im2colType,
|
||||||
RankedTensorType im2colRowType,
|
RankedTensorType im2colRowType,
|
||||||
RankedTensorType gemmInputRowType,
|
RankedTensorType gemmInputRowsType,
|
||||||
int64_t batchSize,
|
int64_t batchSize,
|
||||||
int64_t numChannelsIn,
|
int64_t numChannelsIn,
|
||||||
int64_t xHeight,
|
int64_t xHeight,
|
||||||
@@ -176,8 +175,8 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
|
|||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
|
auto im2colComputeOp =
|
||||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
|
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
|
|
||||||
// Pad input with zeros if needed:
|
// Pad input with zeros if needed:
|
||||||
@@ -285,24 +284,10 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> rowResults;
|
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||||
rowResults.reserve(packedNumRows);
|
|
||||||
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(packFactor * patchSize)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
rowResults.push_back(
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
|
||||||
}
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
SmallVector<Value> rows;
|
return im2colComputeOp.getResult(0);
|
||||||
rows.reserve(im2colComputeOp.getNumResults());
|
|
||||||
for (Value result : im2colComputeOp.getResults())
|
|
||||||
rows.push_back(result);
|
|
||||||
return rows;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createCollectedConvOutput(ValueRange gemmRows,
|
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||||
@@ -320,16 +305,12 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
|||||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||||
Value gemmOut;
|
Value gemmOut;
|
||||||
if (packFactor == 1) {
|
if (packFactor == 1) {
|
||||||
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
|
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
Value packedOutput =
|
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
gemmRowArgs.size() == 1
|
|
||||||
? gemmRowArgs.front()
|
|
||||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
|
||||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
expandedType,
|
expandedType,
|
||||||
@@ -388,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
auto wType = cast<RankedTensorType>(w.getType());
|
auto wType = cast<RankedTensorType>(w.getType());
|
||||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
if (!xType.hasStaticShape()) {
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||||
|
return failure();
|
||||||
// We need to understand what is group
|
}
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
if (!wType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (xType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (wType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (outType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (convOp.getGroup() != 1) {
|
||||||
|
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t batchSize = xType.getDimSize(0);
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
@@ -409,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const auto dilationsAttr = convOp.getDilations();
|
const auto dilationsAttr = convOp.getDilations();
|
||||||
const auto padsAttr = convOp.getPads();
|
const auto padsAttr = convOp.getPads();
|
||||||
|
|
||||||
|
if (stridesAttr && stridesAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (padsAttr && padsAttr->size() != 4) {
|
||||||
|
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||||
@@ -449,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
padWidthBegin = totalPadW - padWidthEnd;
|
padWidthBegin = totalPadW - padWidthEnd;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||||
|
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
// "NOTSET" or "VALID" -> all pads stay 0
|
// "NOTSET" or "VALID" -> all pads stay 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -505,17 +526,21 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
||||||
//
|
//
|
||||||
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
|
// We want to process N pixels at the same time. Instead of doing N separate operations
|
||||||
// the row it needs instead of receiving a full packed tensor and slicing it locally.
|
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
||||||
auto gemmInputRowType =
|
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
||||||
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||||
auto gemmOutputRowType =
|
// B_packed: [N * patchSize, N * cOut]
|
||||||
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||||
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||||
|
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||||
|
auto gemmOutputRowsType =
|
||||||
|
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
|
Value gemmInputRows = createIm2colRowComputes(x,
|
||||||
xType,
|
xType,
|
||||||
im2colType,
|
im2colType,
|
||||||
rowType,
|
rowType,
|
||||||
gemmInputRowType,
|
gemmInputRowsType,
|
||||||
batchSize,
|
batchSize,
|
||||||
numChannelsIn,
|
numChannelsIn,
|
||||||
xHeight,
|
xHeight,
|
||||||
@@ -552,13 +577,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
Value gemmC = buildPackedBias(
|
Value gemmC = buildPackedBias(
|
||||||
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
|
||||||
SmallVector<Value> gemmRows;
|
Value gemmRows = ONNXGemmOp::create(rewriter,
|
||||||
gemmRows.reserve(gemmInputRows.size());
|
|
||||||
for (Value gemmInputRow : gemmInputRows) {
|
|
||||||
Value gemmRow = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
loc,
|
||||||
gemmOutputRowType,
|
gemmOutputRowsType,
|
||||||
gemmInputRow,
|
gemmInputRows,
|
||||||
gemmB,
|
gemmB,
|
||||||
gemmC,
|
gemmC,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
@@ -566,11 +588,9 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
rewriter.getBoolAttr(false),
|
rewriter.getBoolAttr(false),
|
||||||
rewriter.getBoolAttr(false))
|
rewriter.getBoolAttr(false))
|
||||||
.getY();
|
.getY();
|
||||||
gemmRows.push_back(gemmRow);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(convOp,
|
rewriter.replaceOp(convOp,
|
||||||
createCollectedConvOutput(gemmRows,
|
createCollectedConvOutput(ValueRange {gemmRows},
|
||||||
convOp.getType(),
|
convOp.getType(),
|
||||||
gemmOutType,
|
gemmOutType,
|
||||||
nhwcType,
|
nhwcType,
|
||||||
|
|||||||
@@ -5,8 +5,9 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -15,13 +16,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
|
|
||||||
strides[i] = strides[i + 1] * shape[i + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
|||||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value transposeForSpatial(Value value,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<int64_t> permutation,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||||
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
resultType,
|
||||||
|
value,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1}
|
||||||
|
});
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
resultType,
|
||||||
|
input,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1}
|
||||||
|
});
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -65,6 +105,72 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
|
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||||
|
RankedTensorType matrixType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t numRows = matrixType.getDimSize(0);
|
||||||
|
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||||
|
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||||
|
|
||||||
|
if (isHostFoldableValue(matrix)) {
|
||||||
|
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||||
|
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto buildRowSlices = [&](Value matrixArg) {
|
||||||
|
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||||
|
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||||
|
};
|
||||||
|
|
||||||
|
auto cloneBatchInputChainIntoSliceCompute =
|
||||||
|
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
|
||||||
|
auto sliceCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
|
||||||
|
Value transformedMatrix = input;
|
||||||
|
if (!chainOps.empty()) {
|
||||||
|
IRMapping mapper;
|
||||||
|
mapper.map(rootValue, input);
|
||||||
|
for (Operation* chainOp : chainOps)
|
||||||
|
rewriter.clone(*chainOp, mapper);
|
||||||
|
transformedMatrix = cast<Value>(mapper.lookup(matrix));
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
|
||||||
|
});
|
||||||
|
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
|
||||||
|
return rowSlices;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<Operation*> chainOps;
|
||||||
|
Value rootValue = matrix;
|
||||||
|
while (Operation* definingOp = rootValue.getDefiningOp()) {
|
||||||
|
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
|
||||||
|
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||||
|
return cloneBatchInputChainIntoSliceCompute(
|
||||||
|
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (definingOp->getNumOperands() != 1)
|
||||||
|
break;
|
||||||
|
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||||
|
break;
|
||||||
|
|
||||||
|
chainOps.push_back(definingOp);
|
||||||
|
rootValue = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||||
|
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
@@ -75,13 +181,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
Value b = gemmOpAdaptor.getB();
|
Value b = gemmOpAdaptor.getB();
|
||||||
Value c = gemmOpAdaptor.getC();
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
|
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
auto aType = cast<RankedTensorType>(a.getType());
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
if (!aType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t numOutRows = aType.getDimSize(0);
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
|
|
||||||
@@ -105,47 +221,43 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||||
if (cType.getRank() == 1) {
|
if (cType.getRank() == 1) {
|
||||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
c = tensor::ExpandShapeOp::create(rewriter,
|
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||||
loc,
|
|
||||||
expandedType,
|
|
||||||
c,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1}
|
|
||||||
});
|
|
||||||
cType = expandedType;
|
cType = expandedType;
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||||
|
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||||
|
SmallVector<Value> cSlices;
|
||||||
|
if (hasC && cHasNumOutRows)
|
||||||
|
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
|
||||||
|
|
||||||
SmallVector<Value> gemvOps;
|
SmallVector<Value> gemvOps;
|
||||||
gemvOps.reserve(numOutRows);
|
gemvOps.reserve(static_cast<size_t>(numOutRows));
|
||||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
|
||||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
|
||||||
|
|
||||||
Value cSlice = c;
|
Value cSlice = c;
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
if (cHasNumOutRows) {
|
if (cHasNumOutRows)
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
cSlice = cSlices[static_cast<size_t>(rowIdx)];
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
else if (!isVectorShape(getTensorShape(c))) {
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
return failure();
|
||||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
|
||||||
}
|
}
|
||||||
else
|
|
||||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outRowType,
|
outRowType,
|
||||||
aSlice,
|
aSlices[static_cast<size_t>(rowIdx)],
|
||||||
b,
|
b,
|
||||||
cSlice,
|
cSlice,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
@@ -156,8 +268,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
@@ -189,20 +300,31 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||||
if (cType.getRank() == 1) {
|
if (cType.getRank() == 1) {
|
||||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
c = tensor::ExpandShapeOp::create(rewriter,
|
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
|
||||||
gemmLoc,
|
|
||||||
expandedType,
|
|
||||||
c,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1}
|
|
||||||
});
|
|
||||||
cType = expandedType;
|
cType = expandedType;
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
if (!aType.hasStaticShape()) {
|
||||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!bType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||||
// Not a gemv
|
// Not a gemv
|
||||||
@@ -210,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
if (transA) {
|
if (transA) {
|
||||||
auto aShape = aType.getShape();
|
auto aShape = aType.getShape();
|
||||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
|
||||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||||
|
aType = cast<RankedTensorType>(a.getType());
|
||||||
}
|
}
|
||||||
if (transB) {
|
if (transB) {
|
||||||
auto bShape = bType.getShape();
|
auto bShape = bType.getShape();
|
||||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||||
bType = cast<RankedTensorType>(b.getType());
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||||
auto bNumVSlices = aNumHSlices;
|
auto bNumVSlices = aNumHSlices;
|
||||||
auto bLastVSliceSize = aLastHSliceSize;
|
|
||||||
auto cNumHSlices = bNumHSlices;
|
auto cNumHSlices = bNumHSlices;
|
||||||
auto cLastHSliceSize = bLastHSliceSize;
|
auto cLastHSliceSize = bLastHSliceSize;
|
||||||
auto outNumHSlices = cNumHSlices;
|
auto outNumHSlices = cNumHSlices;
|
||||||
@@ -281,19 +403,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||||
|
|
||||||
auto computeOp = createSpatCompute(
|
auto computeOp = createSpatCompute(
|
||||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||||
SmallVector<Value> vmmOutputs;
|
SmallVector<Value> vmmOutputs;
|
||||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||||
vmmOutputs.push_back(
|
vmmOutputs.push_back(
|
||||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
if (vmmOutputs.empty()) {
|
||||||
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||||
|
return success();
|
||||||
});
|
});
|
||||||
|
if (failed(computeOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
partialResults.push_back(computeOp.getResult(0));
|
partialResults.push_back(computeOp->getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
@@ -313,8 +441,126 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp =
|
||||||
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
});
|
||||||
|
|
||||||
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
|
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Location loc = gemmOp.getLoc();
|
||||||
|
Value a = gemmOpAdaptor.getA();
|
||||||
|
Value b = gemmOpAdaptor.getB();
|
||||||
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
|
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
|
auto bType = cast<RankedTensorType>(b.getType());
|
||||||
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
|
if (!aType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!bType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
|
if (numOutRows <= 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
|
||||||
|
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|
||||||
|
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||||
|
if (failed(scaledB))
|
||||||
|
return failure();
|
||||||
|
b = *scaledB;
|
||||||
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
|
|
||||||
|
if (gemmOpAdaptor.getTransB()) {
|
||||||
|
auto bShape = bType.getShape();
|
||||||
|
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||||
|
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||||
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
|
}
|
||||||
|
(void) bType;
|
||||||
|
|
||||||
|
if (!isHostFoldableValue(b))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value sharedBias;
|
||||||
|
if (hasC) {
|
||||||
|
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||||
|
if (failed(scaledC))
|
||||||
|
return failure();
|
||||||
|
c = *scaledC;
|
||||||
|
auto cType = cast<RankedTensorType>(c.getType());
|
||||||
|
if (cType.getRank() == 1) {
|
||||||
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
|
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||||
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
|
}
|
||||||
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
||||||
|
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
||||||
|
return failure();
|
||||||
|
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||||
|
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
|
||||||
|
sharedBias = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||||
|
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
||||||
|
|
||||||
|
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||||
|
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
||||||
|
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
||||||
|
|
||||||
|
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange(resultTypes),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||||
|
ValueRange(weights),
|
||||||
|
ValueRange(aSlices));
|
||||||
|
|
||||||
|
Block* body = rewriter.createBlock(
|
||||||
|
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||||
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
|
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||||
|
Value laneResult = vmmResult;
|
||||||
|
if (sharedBias)
|
||||||
|
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
||||||
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
@@ -322,6 +568,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
|
||||||
patterns.insert<GemmToManyGemv>(ctx);
|
patterns.insert<GemmToManyGemv>(ctx);
|
||||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,9 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -14,7 +15,102 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractBatchMatrix(Value value,
|
||||||
|
int64_t batchIndex,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t rows,
|
||||||
|
int64_t cols,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
if (type.getRank() == 2)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets = {
|
||||||
|
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||||
|
auto buildMatrix = [&](Value input) -> Value {
|
||||||
|
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
matrixType,
|
||||||
|
slice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return buildMatrix(value);
|
||||||
|
|
||||||
|
auto batchMatrixCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||||
|
});
|
||||||
|
return batchMatrixCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
if (type.getRank() == 2) {
|
||||||
|
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
RankedTensorType transposedType;
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
if (type.getRank() == 2) {
|
||||||
|
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
|
perm = {1, 0};
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
perm = {0, 2, 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
|
});
|
||||||
|
return transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||||
|
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
for (Value input : inputs)
|
||||||
|
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||||
|
outputShape[axis] = concatDimSize;
|
||||||
|
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
|
||||||
|
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||||
|
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||||
|
|
||||||
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||||
|
});
|
||||||
|
return concatCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
@@ -24,80 +120,125 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|| !outType.hasStaticShape())
|
|| !outType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||||
|
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||||
|
return failure();
|
||||||
|
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||||
|
|| !haveStaticPositiveShape(outType.getShape()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t batch = rhsType.getDimSize(0);
|
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||||
const int64_t k = rhsType.getDimSize(1);
|
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||||
const int64_t n = rhsType.getDimSize(2);
|
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||||
const int64_t m = lhsType.getDimSize(0);
|
|
||||||
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||||
|| outType.getDimSize(2) != n)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||||
|
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||||
|
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||||
|
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||||
|
if (k != rhsK)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (outType.getRank() == 2) {
|
||||||
|
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||||
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
|
||||||
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
|
||||||
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
|
||||||
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
|
||||||
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
|
||||||
|
|
||||||
Value lhsTransposed =
|
Value lhs = matmulOp.getA();
|
||||||
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
Value rhs = matmulOp.getB();
|
||||||
|
int64_t lhsBatchForGemm = lhsBatch;
|
||||||
|
int64_t rhsBatchForGemm = rhsBatch;
|
||||||
|
int64_t gemmM = m;
|
||||||
|
int64_t gemmK = k;
|
||||||
|
int64_t gemmN = n;
|
||||||
|
if (useTransposedForm) {
|
||||||
|
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||||
|
lhsBatchForGemm = rhsBatch;
|
||||||
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
|
rhsBatchForGemm = lhsBatch;
|
||||||
|
gemmM = n;
|
||||||
|
gemmN = m;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
||||||
|
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
||||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
SmallVector<Value> gemmRows;
|
if (outType.getRank() == 2) {
|
||||||
gemmRows.reserve(batch * n);
|
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
SmallVector<OpFoldResult> offsets = {
|
|
||||||
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
|
||||||
SmallVector<OpFoldResult> strides = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Value rhsSlice =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
|
||||||
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
loc,
|
||||||
rhsRowType,
|
gemmType,
|
||||||
rhsSlice,
|
lhsMatrix,
|
||||||
SmallVector<ReassociationIndices> {
|
rhsMatrix,
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
|
|
||||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
gemmRowType,
|
|
||||||
rhsRow,
|
|
||||||
lhsTransposed,
|
|
||||||
none,
|
none,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter.getBoolAttr(false),
|
rewriter.getBoolAttr(false),
|
||||||
rewriter.getBoolAttr(false));
|
rewriter.getBoolAttr(false))
|
||||||
gemmRows.push_back(gemmOp.getY());
|
.getY();
|
||||||
}
|
if (useTransposedForm) {
|
||||||
}
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
});
|
});
|
||||||
|
gemmResult = transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
Value gemmOut = concatComputeOp.getResult(0);
|
SmallVector<Value> batchResults;
|
||||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
batchResults.reserve(batch);
|
||||||
|
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||||
|
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
|
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
gemmExpandedType,
|
gemmType,
|
||||||
gemmOut,
|
lhsMatrix,
|
||||||
|
rhsMatrix,
|
||||||
|
none,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
auto batchResultCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
||||||
|
Value resultMatrix = input;
|
||||||
|
if (useTransposedForm) {
|
||||||
|
resultMatrix = ONNXTransposeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||||
|
input,
|
||||||
|
rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
}
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
batchedOutType,
|
||||||
|
resultMatrix,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0, 1},
|
{0, 1},
|
||||||
{2}
|
{2}
|
||||||
});
|
});
|
||||||
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||||
|
});
|
||||||
|
batchResults.push_back(batchResultCompute.getResult(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, result);
|
rewriter.replaceOp(matmulOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -106,7 +247,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<MatMulRank3ToGemm>(ctx);
|
patterns.insert<MatMulToGemm>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -5,8 +5,9 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
|
|||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||||
|
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
for (Value input : inputs)
|
||||||
|
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||||
|
outputShape[axis] = concatDimSize;
|
||||||
|
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
|
||||||
|
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||||
|
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||||
|
|
||||||
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||||
|
});
|
||||||
|
return concatCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
static Value buildReduceMeanKeepdims(Value input,
|
static Value buildReduceMeanKeepdims(Value input,
|
||||||
ArrayRef<bool> reducedAxes,
|
ArrayRef<bool> reducedAxes,
|
||||||
int64_t axis,
|
int64_t axis,
|
||||||
@@ -100,8 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
|
|||||||
for (Value slice : slices)
|
for (Value slice : slices)
|
||||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||||
|
|
||||||
return reducedSlices.size() == 1 ? reducedSlices.front()
|
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||||
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||||
@@ -116,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
|||||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensor::CollapseShapeOp::create(
|
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||||
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
|
if (isHostFoldableValue(keepdimsValue))
|
||||||
.getResult();
|
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||||
|
|
||||||
|
auto squeezeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
|
||||||
|
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
|
||||||
|
});
|
||||||
|
return squeezeCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||||
|
|||||||
@@ -1,18 +1,20 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/APFloat.h"
|
||||||
|
#include "llvm/ADT/APInt.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -31,13 +33,6 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
|||||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
|
||||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
|
||||||
if (values.size() == 1)
|
|
||||||
return values.front();
|
|
||||||
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||||
@@ -52,27 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
|||||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ReduceOp>
|
static Value createPoolFillElement(
|
||||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
if (!useMinimumValue)
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||||
|
|
||||||
Value reduced = windowValues.front();
|
if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||||
for (Value value : windowValues.drop_front())
|
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
|
||||||
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
|
||||||
return reduced;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
|
||||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
|
||||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
|
||||||
if (divisor == 1)
|
}
|
||||||
return reducedWindow;
|
|
||||||
|
|
||||||
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
llvm_unreachable("unsupported pool element type");
|
||||||
double scale = 1.0 / static_cast<double>(divisor);
|
}
|
||||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
|
||||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
static Value createPoolFillTensor(
|
||||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
|
||||||
|
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
||||||
|
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PoolOp>
|
||||||
|
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
PoolOp poolOp,
|
||||||
|
Value input,
|
||||||
|
RankedTensorType inputType,
|
||||||
|
int64_t padTop,
|
||||||
|
int64_t padLeft,
|
||||||
|
int64_t padBottom,
|
||||||
|
int64_t padRight) {
|
||||||
|
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
|
||||||
|
return input;
|
||||||
|
|
||||||
|
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
|
||||||
|
inputType.getDimSize(1),
|
||||||
|
inputType.getDimSize(2) + padTop + padBottom,
|
||||||
|
inputType.getDimSize(3) + padLeft + padRight},
|
||||||
|
inputType.getElementType(),
|
||||||
|
inputType.getEncoding());
|
||||||
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padTop),
|
||||||
|
rewriter.getIndexAttr(padLeft)};
|
||||||
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padBottom),
|
||||||
|
rewriter.getIndexAttr(padRight)};
|
||||||
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
|
||||||
|
auto* padBlock = new Block();
|
||||||
|
for (int index = 0; index < paddedType.getRank(); ++index)
|
||||||
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
|
padOp.getRegion().push_back(padBlock);
|
||||||
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
|
Value padValue = createPoolFillElement(
|
||||||
|
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||||
|
tensor::YieldOp::create(rewriter, loc, padValue);
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
return padOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Operation* op,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t channels,
|
||||||
|
int64_t inputHeight,
|
||||||
|
int64_t inputWidth,
|
||||||
|
int64_t outputHeight,
|
||||||
|
int64_t outputWidth,
|
||||||
|
int64_t kernelHeight,
|
||||||
|
int64_t kernelWidth,
|
||||||
|
int64_t strideHeight,
|
||||||
|
int64_t strideWidth,
|
||||||
|
int64_t dilationHeight,
|
||||||
|
int64_t dilationWidth,
|
||||||
|
int64_t padTop,
|
||||||
|
int64_t padLeft,
|
||||||
|
bool countIncludePad) {
|
||||||
|
auto elemType = dyn_cast<FloatType>(outType.getElementType());
|
||||||
|
if (!elemType) {
|
||||||
|
op->emitOpError("AveragePool lowering requires a floating-point element type");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
|
||||||
|
SmallVector<Attribute> scaleValues;
|
||||||
|
scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
|
||||||
|
for (int64_t channel = 0; channel < channels; ++channel) {
|
||||||
|
(void) channel;
|
||||||
|
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||||
|
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||||
|
int64_t validCount = 0;
|
||||||
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
|
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||||
|
if (inH < 0 || inH >= inputHeight)
|
||||||
|
continue;
|
||||||
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
|
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||||
|
if (inW < 0 || inW >= inputWidth)
|
||||||
|
continue;
|
||||||
|
++validCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
|
||||||
|
if (divisor <= 0) {
|
||||||
|
op->emitOpError("AveragePool divisor must be positive");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PoolOp>
|
template <typename PoolOp>
|
||||||
@@ -150,49 +244,90 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(void) padBottom;
|
|
||||||
(void) padRight;
|
|
||||||
|
|
||||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||||
|
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
|
||||||
|
const bool countIncludePad = [&]() {
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
|
||||||
|
return poolOp.getCountIncludePad() == 1;
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
|
Value averageScaleTensor;
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
|
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
|
||||||
|
loc,
|
||||||
|
poolOp,
|
||||||
|
outType,
|
||||||
|
channels,
|
||||||
|
inputHeight,
|
||||||
|
inputWidth,
|
||||||
|
outputHeight,
|
||||||
|
outputWidth,
|
||||||
|
kernelHeight,
|
||||||
|
kernelWidth,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
padTop,
|
||||||
|
padLeft,
|
||||||
|
countIncludePad);
|
||||||
|
if (failed(maybeAverageScaleTensor))
|
||||||
|
return failure();
|
||||||
|
averageScaleTensor = *maybeAverageScaleTensor;
|
||||||
|
}
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||||
SmallVector<Value> batchResults;
|
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||||
batchResults.reserve(batchSize);
|
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||||
|
|
||||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
SmallVector<Value> rows;
|
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||||
rows.reserve(outputHeight);
|
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
||||||
|
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
||||||
|
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
||||||
|
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||||
|
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||||
|
|
||||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||||
SmallVector<Value> rowPixels;
|
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||||
rowPixels.reserve(outputWidth);
|
|
||||||
|
|
||||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
Value outputPatchIndex = outputLoop.getInductionVar();
|
||||||
SmallVector<Value> outputChannelTiles;
|
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
|
||||||
outputChannelTiles.reserve(channelTileCount);
|
|
||||||
|
|
||||||
|
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||||
|
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||||
|
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||||
|
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||||
|
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||||
|
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||||
|
|
||||||
|
Value updatedOutput = pooledOutputAcc;
|
||||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||||
|
Value reducedWindow = createPoolFillTensor(
|
||||||
|
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||||
|
|
||||||
SmallVector<Value> windowValues;
|
|
||||||
windowValues.reserve(kernelHeight * kernelWidth);
|
|
||||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
Value paddedInH = windowBaseH;
|
||||||
if (inH < 0 || inH >= inputHeight)
|
if (kernelH * dilationHeight != 0) {
|
||||||
continue;
|
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
||||||
|
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||||
|
}
|
||||||
|
|
||||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
Value paddedInW = windowBaseW;
|
||||||
if (inW < 0 || inW >= inputWidth)
|
if (kernelW * dilationWidth != 0) {
|
||||||
continue;
|
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
||||||
|
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
SmallVector<OpFoldResult> offsets = {batchIndex,
|
||||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||||
rewriter.getIndexAttr(inH),
|
paddedInH,
|
||||||
rewriter.getIndexAttr(inW)};
|
paddedInW};
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(tileChannels),
|
rewriter.getIndexAttr(tileChannels),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
@@ -202,37 +337,51 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1)};
|
||||||
Value windowValue =
|
Value windowValue =
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||||
windowValues.push_back(windowValue);
|
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (windowValues.empty())
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
|
||||||
|
|
||||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
|
||||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
|
||||||
const int64_t divisor =
|
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
outHeightIndex,
|
||||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
outWidthIndex};
|
||||||
|
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(tileChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||||
|
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||||
|
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||||
}
|
}
|
||||||
|
|
||||||
outputChannelTiles.push_back(reducedWindow);
|
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
|
||||||
|
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||||
|
outHeightIndex,
|
||||||
|
outWidthIndex};
|
||||||
|
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(tileChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
updatedOutput = tensor::InsertSliceOp::create(
|
||||||
|
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||||
}
|
}
|
||||||
|
|
||||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||||
}
|
|
||||||
|
|
||||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
rewriter.setInsertionPointAfter(outputLoop);
|
||||||
}
|
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
|
||||||
|
|
||||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
|
||||||
}
|
|
||||||
|
|
||||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
if (failed(computeOp))
|
if (failed(computeOp))
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
|
|||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||||
|
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
for (Value input : inputs)
|
||||||
|
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||||
|
outputShape[axis] = concatDimSize;
|
||||||
|
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
|
||||||
|
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||||
|
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||||
|
|
||||||
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||||
|
});
|
||||||
|
return concatCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
@@ -47,8 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
|
|||||||
for (Value slice : slices)
|
for (Value slice : slices)
|
||||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||||
|
|
||||||
return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
|
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||||
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||||
@@ -93,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
Value transposedInput = preTransposeCompute.getResult(0);
|
Value transposedInput = preTransposeCompute.getResult(0);
|
||||||
Value transposedResult = buildSoftmax(
|
Value transposedResult = buildSoftmax(
|
||||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||||
result = ONNXTransposeOp::create(
|
auto postTransposeCompute =
|
||||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||||
|
Value transposed = ONNXTransposeOp::create(
|
||||||
|
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||||
|
});
|
||||||
|
result = postTransposeCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(softmaxOp, result);
|
rewriter.replaceOp(softmaxOp, result);
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -17,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
|||||||
auto inputs = adaptor.getInputs();
|
auto inputs = adaptor.getInputs();
|
||||||
int64_t axis = adaptor.getAxis();
|
int64_t axis = adaptor.getAxis();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
|
if (llvm::all_of(inputs, isHostFoldableValue)) {
|
||||||
|
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute(
|
||||||
|
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(
|
||||||
|
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
|
|||||||
}
|
}
|
||||||
if (slices.empty())
|
if (slices.empty())
|
||||||
return {};
|
return {};
|
||||||
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
return createSpatConcat(rewriter, loc, axis, slices);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
||||||
}
|
}
|
||||||
result = rows.size() == 1
|
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
|
||||||
? rows.front()
|
|
||||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -95,18 +98,33 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<ReassociationIndices> reassociation;
|
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||||
if (sourceType.getRank() > resultType.getRank()
|
if (isHostFoldableValue(adaptor.getData())) {
|
||||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
||||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sourceType.getRank() < resultType.getRank()
|
auto computeOp = createSpatCompute<1>(
|
||||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
||||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
Value reshaped = buildReshape(data);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(reshapeOp, computeOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
};
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation;
|
||||||
|
if (sourceType.getRank() > resultType.getRank()
|
||||||
|
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||||
|
return replaceWithReshape([&](Value data) {
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (sourceType.getRank() < resultType.getRank()
|
||||||
|
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||||
|
return replaceWithReshape([&](Value data) {
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||||
|
});
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
|
|||||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||||
}
|
}
|
||||||
|
|
||||||
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
return createSpatConcat(rewriter, loc, axis, slices);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -23,7 +25,10 @@ static Value extractSliceAt(
|
|||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||||
sizes[axis] = rewriter.getIndexAttr(size);
|
sizes[axis] = rewriter.getIndexAttr(size);
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
SmallVector<int64_t> resultShape(inputType.getShape());
|
||||||
|
resultShape[axis] = size;
|
||||||
|
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||||
@@ -44,23 +49,42 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
outputs.reserve(splitOp.getNumResults());
|
outputs.reserve(splitOp.getNumResults());
|
||||||
|
|
||||||
int64_t offset = 0;
|
int64_t offset = 0;
|
||||||
|
SmallVector<RankedTensorType> resultTypes;
|
||||||
|
resultTypes.reserve(splitOp.getNumResults());
|
||||||
|
SmallVector<int64_t> sliceSizes;
|
||||||
|
sliceSizes.reserve(splitOp.getNumResults());
|
||||||
for (Value result : splitOp.getResults()) {
|
for (Value result : splitOp.getResults()) {
|
||||||
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
int64_t sliceSize = resultType.getShape()[axis];
|
resultTypes.push_back(resultType);
|
||||||
auto computeOp =
|
sliceSizes.push_back(resultType.getShape()[axis]);
|
||||||
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
|
|
||||||
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
|
|
||||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
|
|
||||||
});
|
|
||||||
outputs.push_back(computeOp.getResult(0));
|
|
||||||
offset += sliceSize;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isHostFoldableValue(adaptor.getInput())) {
|
||||||
|
for (int64_t sliceSize : sliceSizes) {
|
||||||
|
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||||
|
offset += sliceSize;
|
||||||
|
}
|
||||||
rewriter.replaceOp(splitOp, outputs);
|
rewriter.replaceOp(splitOp, outputs);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) {
|
||||||
|
SmallVector<Value> runtimeOutputs;
|
||||||
|
runtimeOutputs.reserve(resultTypes.size());
|
||||||
|
int64_t runtimeOffset = 0;
|
||||||
|
for (int64_t sliceSize : sliceSizes) {
|
||||||
|
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
|
||||||
|
runtimeOffset += sliceSize;
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
|
||||||
|
});
|
||||||
|
|
||||||
|
rewriter.replaceOp(splitOp, computeOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@@ -0,0 +1,265 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||||
|
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||||
|
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isDirectConstantValue(Value value) {
|
||||||
|
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||||
|
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||||
|
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (batchOp.getLaneCount() != 1)
|
||||||
|
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||||
|
|
||||||
|
auto loc = batchOp.getLoc();
|
||||||
|
rewriter.setInsertionPoint(batchOp);
|
||||||
|
auto computeOp =
|
||||||
|
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||||
|
computeOp.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||||
|
|
||||||
|
Block& templateBlock = batchOp.getBody().front();
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||||
|
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||||
|
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||||
|
blockArgTypes.push_back(arg.getType());
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* newBlock =
|
||||||
|
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||||
|
mapper.map(oldArg, newArg);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
for (Operation& op : templateBlock)
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||||
|
rewriter.eraseOp(batchOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||||
|
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||||
|
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||||
|
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||||
|
bool needsRewrite = false;
|
||||||
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
if (inputIdx >= oldBlock.getNumArguments())
|
||||||
|
continue;
|
||||||
|
if (!isWeightLikeComputeOperand(input))
|
||||||
|
continue;
|
||||||
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||||
|
continue;
|
||||||
|
promoteInput[inputIdx] = true;
|
||||||
|
needsRewrite = true;
|
||||||
|
}
|
||||||
|
if (!needsRewrite)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(compute);
|
||||||
|
|
||||||
|
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
|
SmallVector<Value> newInputs;
|
||||||
|
SmallVector<Type> newInputTypes;
|
||||||
|
SmallVector<Location> newInputLocs;
|
||||||
|
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||||
|
newInputs.reserve(compute.getInputs().size());
|
||||||
|
newInputTypes.reserve(compute.getInputs().size());
|
||||||
|
newInputLocs.reserve(compute.getInputs().size());
|
||||||
|
|
||||||
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
if (promoteInput[inputIdx]) {
|
||||||
|
newWeights.push_back(input);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
newInputs.push_back(input);
|
||||||
|
newInputTypes.push_back(input.getType());
|
||||||
|
newInputLocs.push_back(input.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newCompute =
|
||||||
|
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||||
|
auto* newBlock =
|
||||||
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||||
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
|
IRRewriter bodyRewriter(rewriter.getContext());
|
||||||
|
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
size_t newInputIdx = 0;
|
||||||
|
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||||
|
if (!promoteInput[oldInputIdx]) {
|
||||||
|
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||||
|
if (failed(clonedValue))
|
||||||
|
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||||
|
mapper.map(oldArg, *clonedValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation& op : oldBlock.without_terminator())
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||||
|
SmallVector<Value> newYieldOperands;
|
||||||
|
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||||
|
for (Value operand : oldYield.getOutputs()) {
|
||||||
|
auto mapped = mapper.lookupOrNull(operand);
|
||||||
|
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||||
|
|
||||||
|
rewriter.replaceOp(compute, newCompute.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||||
|
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||||
|
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||||
|
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||||
|
bool needsRewrite = false;
|
||||||
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
if (inputIdx >= oldBlock.getNumArguments())
|
||||||
|
continue;
|
||||||
|
if (!isWeightLikeComputeOperand(input))
|
||||||
|
continue;
|
||||||
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||||
|
continue;
|
||||||
|
promoteInput[inputIdx] = true;
|
||||||
|
needsRewrite = true;
|
||||||
|
}
|
||||||
|
if (!needsRewrite)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(compute);
|
||||||
|
|
||||||
|
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
|
SmallVector<Value> newInputs;
|
||||||
|
SmallVector<Type> newInputTypes;
|
||||||
|
SmallVector<Location> newInputLocs;
|
||||||
|
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||||
|
newInputs.reserve(compute.getInputs().size());
|
||||||
|
newInputTypes.reserve(compute.getInputs().size());
|
||||||
|
newInputLocs.reserve(compute.getInputs().size());
|
||||||
|
|
||||||
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
if (promoteInput[inputIdx]) {
|
||||||
|
newWeights.push_back(input);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
newInputs.push_back(input);
|
||||||
|
newInputTypes.push_back(input.getType());
|
||||||
|
newInputLocs.push_back(input.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newCompute =
|
||||||
|
spatial::SpatComputeBatch::create(rewriter,
|
||||||
|
compute.getLoc(),
|
||||||
|
compute.getResultTypes(),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||||
|
newWeights,
|
||||||
|
newInputs);
|
||||||
|
auto* newBlock =
|
||||||
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||||
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
|
IRRewriter bodyRewriter(rewriter.getContext());
|
||||||
|
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
size_t newInputIdx = 0;
|
||||||
|
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||||
|
if (!promoteInput[oldInputIdx]) {
|
||||||
|
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||||
|
if (failed(clonedValue))
|
||||||
|
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||||
|
mapper.map(oldArg, *clonedValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation& op : oldBlock.without_terminator())
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||||
|
SmallVector<Value> newYieldOperands;
|
||||||
|
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||||
|
for (Value operand : oldYield.getOutputs()) {
|
||||||
|
auto mapped = mapper.lookupOrNull(operand);
|
||||||
|
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||||
|
|
||||||
|
rewriter.replaceOp(compute, newCompute.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||||
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
|
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||||
|
markWeightAlways(constantOp);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||||
|
patterns.add<onnxToArithConstant>(ctx);
|
||||||
|
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||||
|
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||||
|
patterns.add<matMulAddToGemm>(ctx);
|
||||||
|
patterns.add<matMulToGemm>(ctx);
|
||||||
|
patterns.add<removeFlattenSameShape>(ctx);
|
||||||
|
populateMatMulRewritePatterns(patterns, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,218 @@
|
|||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir::pim;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||||
|
|
||||||
|
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||||
|
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
|
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
|
||||||
|
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
|
||||||
|
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
|
||||||
|
return coreIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||||
|
IRMapping& mapper,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
|
||||||
|
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||||
|
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||||
|
|
||||||
|
pim::PimSendTensorBatchOp::create(rewriter,
|
||||||
|
sendTensorBatchOp.getLoc(),
|
||||||
|
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||||
|
IRMapping& mapper,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
|
||||||
|
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
|
||||||
|
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||||
|
|
||||||
|
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||||
|
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
|
||||||
|
receiveTensorBatchOp.getLoc(),
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||||
|
if (computeBatchOp.getNumResults() != 0)
|
||||||
|
return computeBatchOp.emitOpError(
|
||||||
|
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
|
||||||
|
|
||||||
|
Location loc = computeBatchOp.getLoc();
|
||||||
|
Block& oldBlock = computeBatchOp.getBody().front();
|
||||||
|
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||||
|
if (oldYield.getNumOperands() != 0)
|
||||||
|
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
|
||||||
|
|
||||||
|
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||||
|
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||||
|
SmallVector<Value> batchInputs;
|
||||||
|
if (!computeBatchOp.getInputs().empty())
|
||||||
|
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeBatchOp);
|
||||||
|
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
||||||
|
ValueRange(batchWeights),
|
||||||
|
ValueRange(batchInputs));
|
||||||
|
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||||
|
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
for (BlockArgument arg : oldBlock.getArguments()) {
|
||||||
|
blockArgTypes.push_back(arg.getType());
|
||||||
|
blockArgLocs.push_back(arg.getLoc());
|
||||||
|
}
|
||||||
|
Block* newBlock =
|
||||||
|
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
|
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
|
||||||
|
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||||
|
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
newArg,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(oldArg, copied);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(capturedTensor))
|
||||||
|
return mapped;
|
||||||
|
|
||||||
|
auto capturedType = cast<ShapedType>(capturedTensor.getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
|
||||||
|
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
capturedTensor,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, capturedTensor))
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(capturedTensor, copied);
|
||||||
|
return copied;
|
||||||
|
};
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
for (Operation& op : oldBlock) {
|
||||||
|
if (isa<spatial::SpatYieldOp>(op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||||
|
pim::PimSendBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
mapper.lookup(sendBatchOp.getInput()),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||||
|
sendBatchOp.getTargetCoreIdsAttr());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||||
|
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||||
|
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||||
|
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||||
|
receiveBatchOp.getSourceCoreIdsAttr())
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(receiveBatchOp.getOutput(), received);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||||
|
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||||
|
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||||
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
auto clonedTensor = cloned->getResult(0);
|
||||||
|
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||||
|
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
clonedTensor,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(toTensorOp.getResult(), copied);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Value operand : op.getOperands()) {
|
||||||
|
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
|
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
materializeCapturedTensor(operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
PimHaltOp::create(rewriter, loc);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -4,7 +4,16 @@ add_public_tablegen_target(SpatialToPimIncGen)
|
|||||||
|
|
||||||
add_pim_library(OMSpatialToPim
|
add_pim_library(OMSpatialToPim
|
||||||
SpatialToPimPass.cpp
|
SpatialToPimPass.cpp
|
||||||
|
BatchCoreLoweringPatterns.cpp
|
||||||
|
ChannelLoweringPatterns.cpp
|
||||||
|
Cleanup.cpp
|
||||||
Common.cpp
|
Common.cpp
|
||||||
|
ComputeLikeRegionUtils.cpp
|
||||||
|
CoreLoweringPatterns.cpp
|
||||||
|
GlobalTensorMaterialization.cpp
|
||||||
|
PhaseVerification.cpp
|
||||||
|
ReturnPathNormalization.cpp
|
||||||
|
TensorPackingPatterns.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,136 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||||
|
|
||||||
|
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||||
|
pim::PimSendOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
op.getInput(),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
||||||
|
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
|
||||||
|
if (op->use_empty()) {
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
auto outputType = cast<ShapedType>(op.getResult().getType());
|
||||||
|
Value outputBuffer =
|
||||||
|
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||||
|
Value received = pim::PimReceiveOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
op.getResult().getType(),
|
||||||
|
outputBuffer,
|
||||||
|
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||||
|
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
||||||
|
.getOutput();
|
||||||
|
rewriter.replaceOp(op, received);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
||||||
|
for (int32_t targetCoreId : op.getTargetCoreIds())
|
||||||
|
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
||||||
|
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
||||||
|
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
||||||
|
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
||||||
|
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||||
|
Value outputBuffer =
|
||||||
|
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||||
|
Value received =
|
||||||
|
pim::PimReceiveTensorOp::create(
|
||||||
|
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||||
|
.getOutput();
|
||||||
|
rewriter.replaceOp(op, received);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
|
||||||
|
auto inputType = cast<RankedTensorType>(op.getInput().getType());
|
||||||
|
SmallVector<Value> replacements;
|
||||||
|
replacements.reserve(op.getNumResults());
|
||||||
|
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
|
||||||
|
auto outputType = cast<RankedTensorType>(output.getType());
|
||||||
|
SmallVector<OpFoldResult> offsets = {
|
||||||
|
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
|
||||||
|
rewriter.getIndexAttr(inputType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
replacements.push_back(
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
|
||||||
|
.getResult());
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, replacements);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
|
||||||
|
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||||
|
Value outputBuffer =
|
||||||
|
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||||
|
Value concatenated =
|
||||||
|
pim::PimConcatOp::create(
|
||||||
|
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
|
||||||
|
.getOutput();
|
||||||
|
rewriter.replaceOp(op, concatenated);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||||
|
patterns.add<ChannelSendLowering,
|
||||||
|
ChannelReceiveLowering,
|
||||||
|
ChannelSendTensorLowering,
|
||||||
|
ChannelReceiveTensorLowering,
|
||||||
|
ExtractRowsLowering,
|
||||||
|
ConcatLowering>(patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
|
||||||
|
while (!pendingOps.empty()) {
|
||||||
|
bool erasedAnyOp = false;
|
||||||
|
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
|
||||||
|
Operation* opToRemove = *it;
|
||||||
|
if (!opToRemove->use_empty()) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.eraseOp(opToRemove);
|
||||||
|
it = pendingOps.erase(it);
|
||||||
|
erasedAnyOp = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (erasedAnyOp)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (Operation* opToRemove : pendingOps) {
|
||||||
|
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
|
||||||
|
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
|
||||||
|
for (Operation* user : opToRemove->getUsers()) {
|
||||||
|
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
|
||||||
|
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
|
||||||
|
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -7,23 +7,12 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
|
|
||||||
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
|
|
||||||
assert(attr && "required precomputed channel attr is missing");
|
|
||||||
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||||
/*
|
/*
|
||||||
EXAMPLE RUN:
|
EXAMPLE RUN:
|
||||||
@@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
|||||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
|
|
||||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
|
||||||
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
|
|
||||||
}
|
|
||||||
|
|
||||||
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
|
|
||||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
|
||||||
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
|
|
||||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
|
||||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Value createPimReceiveFromSpatialChannel(
|
|
||||||
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
|
||||||
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
|
||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
|
||||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
|
||||||
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
|
||||||
.getOutput();
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||||
auto users = value.getUsers();
|
auto users = value.getUsers();
|
||||||
|
|
||||||
@@ -127,15 +85,26 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
|
|||||||
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
|
||||||
|
for (Operation* user : value.getUsers()) {
|
||||||
|
if (user->getBlock() != operation->getBlock())
|
||||||
|
return true;
|
||||||
|
if (operation->isBeforeInBlock(user))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
|
||||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||||
mlir::Value result = operation->getResult(0);
|
mlir::Value result = operation->getResult(0);
|
||||||
auto resultType = result.getType();
|
auto resultType = result.getType();
|
||||||
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
||||||
|
|
||||||
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
||||||
auto validOperands =
|
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
|
||||||
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
|
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
|
||||||
|
});
|
||||||
auto bestOperand = validOperands.begin();
|
auto bestOperand = validOperands.begin();
|
||||||
|
|
||||||
if (bestOperand != validOperands.end())
|
if (bestOperand != validOperands.end())
|
||||||
|
|||||||
@@ -2,16 +2,10 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
|
||||||
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||||
* its static tensor input.
|
* its static tensor input.
|
||||||
@@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
|||||||
|
|
||||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||||
|
|
||||||
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
|
||||||
|
|
||||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
|
||||||
|
|
||||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
|
||||||
|
|
||||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
|
||||||
|
|
||||||
mlir::Value createPimReceiveFromSpatialChannel(
|
|
||||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||||
return std::distance(range.begin(), range.end());
|
return std::distance(range.begin(), range.end());
|
||||||
@@ -58,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
|
|||||||
|
|
||||||
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
||||||
|
|
||||||
inline mlir::tensor::EmptyOp
|
inline mlir::tensor::EmptyOp
|
||||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||||
|
auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional<unsigned> {
|
||||||
|
if (inputCount == 0)
|
||||||
|
return std::nullopt;
|
||||||
|
unsigned inputBegin = op->getNumOperands() - inputCount;
|
||||||
|
if (operandNumber < inputBegin)
|
||||||
|
return std::nullopt;
|
||||||
|
return operandNumber - inputBegin;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
||||||
|
return getInputIndex(owner, compute.getInputs().size());
|
||||||
|
|
||||||
|
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
|
||||||
|
return getInputIndex(owner, computeBatch.getInputs().size());
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||||
|
Operation* owner,
|
||||||
|
unsigned inputIndex,
|
||||||
|
Value replacement) {
|
||||||
|
Block& body = owner->getRegion(0).front();
|
||||||
|
BlockArgument bodyArgument = body.getArgument(inputIndex);
|
||||||
|
|
||||||
|
rewriter.startOpModification(owner);
|
||||||
|
bodyArgument.replaceAllUsesWith(replacement);
|
||||||
|
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
||||||
|
compute.getInputsMutable().erase(inputIndex);
|
||||||
|
else
|
||||||
|
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||||
|
body.eraseArgument(inputIndex);
|
||||||
|
rewriter.finalizeOpModification(owner);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::optional<unsigned> getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber);
|
||||||
|
|
||||||
|
void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Operation* owner,
|
||||||
|
unsigned inputIndex,
|
||||||
|
mlir::Value replacement);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir::pim;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool isChannelUseChainOp(Operation* op) {
|
||||||
|
return isa<tensor::ExtractSliceOp,
|
||||||
|
tensor::CollapseShapeOp,
|
||||||
|
tensor::ExpandShapeOp,
|
||||||
|
tensor::CastOp,
|
||||||
|
tosa::ReshapeOp,
|
||||||
|
ONNXTransposeOp,
|
||||||
|
pim::PimTransposeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||||
|
for (Value operand : op->getOperands()) {
|
||||||
|
if (mapping.lookupOrNull(operand))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||||
|
|
||||||
|
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||||
|
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
|
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||||
|
return static_cast<int32_t>(fallbackCoreId++);
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||||
|
SmallVectorImpl<Operation*>& helperChain,
|
||||||
|
bool requireReturnUse = true) {
|
||||||
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
if (requireReturnUse
|
||||||
|
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Block& block = computeOp.getBody().front();
|
||||||
|
if (block.getNumArguments() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Operation*> reverseChain;
|
||||||
|
Value currentValue = yieldOp.getOperands().front();
|
||||||
|
Value blockArg = block.getArgument(0);
|
||||||
|
|
||||||
|
while (currentValue != blockArg) {
|
||||||
|
Operation* definingOp = currentValue.getDefiningOp();
|
||||||
|
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
|
||||||
|
return failure();
|
||||||
|
reverseChain.push_back(definingOp);
|
||||||
|
currentValue = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||||
|
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||||
|
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||||
|
return false;
|
||||||
|
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||||
|
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
Block& block = computeOp.getBody().front();
|
||||||
|
if (block.getNumArguments() != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(computeOp);
|
||||||
|
IRMapping mapping;
|
||||||
|
for (Operation& op : block.without_terminator()) {
|
||||||
|
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||||
|
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||||
|
computeOp.getResult(0).replaceAllUsesWith(replacement);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||||
|
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||||
|
state.operationsToRemove.push_back(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||||
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
|
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
SmallVector<Operation*> helperChain;
|
||||||
|
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
auto& block = computeOp.getRegion().front();
|
||||||
|
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
|
||||||
|
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
||||||
|
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
||||||
|
if (!receiveOp || blockArg.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||||
|
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||||
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||||
|
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
||||||
|
Value received = PimReceiveOp::create(
|
||||||
|
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||||
|
.getOutput();
|
||||||
|
blockArg.replaceAllUsesWith(received);
|
||||||
|
markOpToRemove(state, receiveOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||||
|
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
||||||
|
|
||||||
|
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
||||||
|
if (result.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
||||||
|
ReturnPathLoweringResult returnPathResult =
|
||||||
|
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
||||||
|
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||||
|
return failure();
|
||||||
|
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto resultUses = result.getUses();
|
||||||
|
if (rangeLength(resultUses) == 1) {
|
||||||
|
OpOperand& resultUse = *resultUses.begin();
|
||||||
|
Operation* resultUser = resultUse.getOwner();
|
||||||
|
if (isa<spatial::SpatChannelSendOp>(resultUser))
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(yieldOp);
|
||||||
|
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
||||||
|
|
||||||
|
SmallVector<Value> computeWeights;
|
||||||
|
if (!computeOp.getWeights().empty())
|
||||||
|
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
auto coreOp = PimCoreOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
ValueRange(computeWeights),
|
||||||
|
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||||
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
|
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||||
|
if (!blockArg.use_empty())
|
||||||
|
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
||||||
|
block.eraseArguments(0, block.getNumArguments());
|
||||||
|
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||||
|
Block* tempComputeBlock = new Block();
|
||||||
|
computeOp.getBody().push_back(tempComputeBlock);
|
||||||
|
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
||||||
|
PimHaltOp::create(rewriter, computeOp.getLoc());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct CoreLoweringState {
|
||||||
|
size_t& nextCoreId;
|
||||||
|
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||||
|
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||||
|
};
|
||||||
|
|
||||||
|
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,390 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||||
|
std::string name = baseName.str();
|
||||||
|
unsigned suffix = 0;
|
||||||
|
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
|
||||||
|
name = (baseName + "_" + Twine(suffix++)).str();
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
ModuleOp moduleOp,
|
||||||
|
StringRef baseName,
|
||||||
|
MemRefType type,
|
||||||
|
Attribute initialValue = {},
|
||||||
|
UnitAttr constant = {}) {
|
||||||
|
std::string symbolName = makeUniqueSymbolName(moduleOp, baseName);
|
||||||
|
return memref::GlobalOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getStringAttr(symbolName),
|
||||||
|
rewriter.getStringAttr("private"),
|
||||||
|
TypeAttr::get(type),
|
||||||
|
initialValue,
|
||||||
|
constant,
|
||||||
|
IntegerAttr {});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work.
|
||||||
|
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (auto& uses : extractSliceOp->getUses()) {
|
||||||
|
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||||
|
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
|
||||||
|
|
||||||
|
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||||
|
|
||||||
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
|
||||||
|
if (BBArgValue.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
|
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
|
|
||||||
|
if (BBArgValue.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
|
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
{
|
||||||
|
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
|
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.eraseOp(extractSliceOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
||||||
|
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||||
|
Location loc = constantOp.getLoc();
|
||||||
|
|
||||||
|
if (hasWeightAlways(constantOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
||||||
|
if (isa<spatial::SpatCompute>(op))
|
||||||
|
return false;
|
||||||
|
if (isa<func::FuncOp>(op->getParentOp()))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||||
|
|
||||||
|
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||||
|
|
||||||
|
if (constRankedTensorType) {
|
||||||
|
mlir::MemRefType memRefType =
|
||||||
|
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||||
|
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
||||||
|
loc,
|
||||||
|
constantOp->getParentOfType<ModuleOp>(),
|
||||||
|
"const",
|
||||||
|
memRefType,
|
||||||
|
constantOp.getValueAttr(),
|
||||||
|
rewriter.getUnitAttr());
|
||||||
|
std::string argName = globalOp.getSymName().str();
|
||||||
|
|
||||||
|
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||||
|
|
||||||
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||||
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
|
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(rewriter,
|
||||||
|
spatComputeBatch.getOperation(),
|
||||||
|
BBArgIndex,
|
||||||
|
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
{
|
||||||
|
|
||||||
|
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||||
|
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||||
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||||
|
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||||
|
|
||||||
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||||
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
|
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
|
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||||
|
}
|
||||||
|
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||||
|
if (!mapSpatComputeToConst.contains(parent)) {
|
||||||
|
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||||
|
}
|
||||||
|
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||||
|
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||||
|
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
||||||
|
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
||||||
|
}
|
||||||
|
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (constantOp->use_empty())
|
||||||
|
rewriter.eraseOp(constantOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||||
|
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
|
||||||
|
|
||||||
|
if (funcOp.getArguments().empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (llvm::all_of(funcOp.getArguments(),
|
||||||
|
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
|
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
|
||||||
|
if (arg.getUses().empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(funcOp.getOperation());
|
||||||
|
|
||||||
|
assert(isa<mlir::RankedTensorType>(arg.getType()));
|
||||||
|
|
||||||
|
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
|
||||||
|
mlir::MemRefType memRefType =
|
||||||
|
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
|
||||||
|
|
||||||
|
std::string baseName = ("arg_" + Twine(index)).str();
|
||||||
|
auto globalOp = createPrivateMemrefGlobalWithUniqueName(
|
||||||
|
rewriter, loc, funcOp->getParentOfType<ModuleOp>(), baseName, memRefType);
|
||||||
|
std::string argName = globalOp.getSymName().str();
|
||||||
|
|
||||||
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||||
|
auto argUser = argUses.getOwner();
|
||||||
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
||||||
|
}
|
||||||
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||||
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
|
||||||
|
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
rewriter.setInsertionPoint(argUser);
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
rewriter.startOpModification(argUser);
|
||||||
|
argUses.set(toTensor);
|
||||||
|
rewriter.finalizeOpModification(argUser);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||||
|
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||||
|
patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
moduleOp.walk([&](Operation* op) {
|
||||||
|
if (op->getDialect()->getNamespace() != "spat")
|
||||||
|
return;
|
||||||
|
|
||||||
|
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
|
||||||
|
hasFailure = true;
|
||||||
|
});
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,587 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir::pim;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ReturnUseInfo {
|
||||||
|
size_t returnIndex;
|
||||||
|
SmallVector<Operation*> helperChain;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ConcatReturnUseInfo {
|
||||||
|
size_t returnIndex;
|
||||||
|
SmallVector<int64_t> sliceOffsets;
|
||||||
|
SmallVector<int64_t> concatShape;
|
||||||
|
SmallVector<Operation*> concatChain;
|
||||||
|
SmallVector<Operation*> helperChain;
|
||||||
|
};
|
||||||
|
|
||||||
|
static bool isReturnHelperChainOp(Operation* op) {
|
||||||
|
return isa<tensor::ExtractSliceOp,
|
||||||
|
tensor::CollapseShapeOp,
|
||||||
|
tensor::ExpandShapeOp,
|
||||||
|
tensor::CastOp,
|
||||||
|
tosa::ReshapeOp,
|
||||||
|
ONNXTransposeOp,
|
||||||
|
pim::PimTransposeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
||||||
|
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||||
|
state.operationsToRemove.push_back(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||||
|
std::string name = baseName.str();
|
||||||
|
unsigned suffix = 0;
|
||||||
|
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
|
||||||
|
name = (baseName + "_" + Twine(suffix++)).str();
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
|
||||||
|
int64_t flatIndex = 0;
|
||||||
|
for (size_t i = 0; i < shape.size(); ++i) {
|
||||||
|
flatIndex *= shape[i];
|
||||||
|
flatIndex += indices[i];
|
||||||
|
}
|
||||||
|
return flatIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> indices(shape.size(), 0);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||||
|
indices[dim] = flatIndex % shape[dim];
|
||||||
|
flatIndex /= shape[dim];
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||||
|
SmallVectorImpl<Operation*>& helperChain) {
|
||||||
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
if (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Block& block = computeOp.getBody().front();
|
||||||
|
if (block.getNumArguments() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Operation*> reverseChain;
|
||||||
|
Value currentValue = yieldOp.getOperands().front();
|
||||||
|
Value blockArg = block.getArgument(0);
|
||||||
|
|
||||||
|
while (currentValue != blockArg) {
|
||||||
|
Operation* definingOp = currentValue.getDefiningOp();
|
||||||
|
if (!definingOp || definingOp->getBlock() != &block || !isReturnHelperChainOp(definingOp))
|
||||||
|
return failure();
|
||||||
|
reverseChain.push_back(definingOp);
|
||||||
|
currentValue = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||||
|
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
||||||
|
auto uses = value.getUses();
|
||||||
|
if (rangeLength(uses) != 1)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
SmallVector<Operation*> helperChain;
|
||||||
|
Value currentValue = value;
|
||||||
|
Operation* currentUser = uses.begin()->getOwner();
|
||||||
|
|
||||||
|
while (isReturnHelperChainOp(currentUser)) {
|
||||||
|
helperChain.push_back(currentUser);
|
||||||
|
auto currentUses = currentUser->getResult(0).getUses();
|
||||||
|
if (rangeLength(currentUses) != 1)
|
||||||
|
return std::nullopt;
|
||||||
|
currentValue = currentUser->getResult(0);
|
||||||
|
currentUser = currentUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<func::ReturnOp>(currentUser))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
return ReturnUseInfo {
|
||||||
|
currentValue.getUses().begin()->getOperandNumber(),
|
||||||
|
std::move(helperChain),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||||
|
auto getConcatResult = [](Operation* op) -> Value {
|
||||||
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
|
return tensorConcat.getResult();
|
||||||
|
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||||
|
return spatialConcat.getOutput();
|
||||||
|
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||||
|
return pimConcat.getOutput();
|
||||||
|
return {};
|
||||||
|
};
|
||||||
|
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
|
||||||
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
|
return tensorConcat.getDim();
|
||||||
|
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||||
|
return spatialConcat.getAxis();
|
||||||
|
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||||
|
return pimConcat.getAxis();
|
||||||
|
return std::nullopt;
|
||||||
|
};
|
||||||
|
auto getConcatOperands = [](Operation* op) -> OperandRange {
|
||||||
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
|
return tensorConcat.getOperands();
|
||||||
|
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||||
|
return spatialConcat.getInputs();
|
||||||
|
return cast<pim::PimConcatOp>(op).getInputs();
|
||||||
|
};
|
||||||
|
|
||||||
|
auto uses = value.getUses();
|
||||||
|
if (rangeLength(uses) != 1
|
||||||
|
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
auto valueType = dyn_cast<ShapedType>(value.getType());
|
||||||
|
if (!valueType || !valueType.hasStaticShape())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
|
||||||
|
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
|
||||||
|
SmallVector<Operation*> concatChain;
|
||||||
|
Value currentValue = value;
|
||||||
|
Operation* currentUser = uses.begin()->getOwner();
|
||||||
|
|
||||||
|
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
|
||||||
|
concatChain.push_back(currentUser);
|
||||||
|
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
|
||||||
|
int64_t axis = *getConcatAxis(currentUser);
|
||||||
|
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
|
||||||
|
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
|
||||||
|
|
||||||
|
Value concatResult = getConcatResult(currentUser);
|
||||||
|
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
|
||||||
|
if (!concatType || !concatType.hasStaticShape())
|
||||||
|
return std::nullopt;
|
||||||
|
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
|
||||||
|
|
||||||
|
currentValue = concatResult;
|
||||||
|
auto currentUses = currentValue.getUses();
|
||||||
|
if (rangeLength(currentUses) != 1)
|
||||||
|
return std::nullopt;
|
||||||
|
currentUser = currentUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation*> helperChain;
|
||||||
|
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
||||||
|
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
currentValue = helperCompute.getResult(0);
|
||||||
|
auto currentUses = currentValue.getUses();
|
||||||
|
if (rangeLength(currentUses) != 1)
|
||||||
|
return std::nullopt;
|
||||||
|
currentUser = currentUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
|
while (isReturnHelperChainOp(currentUser)) {
|
||||||
|
helperChain.push_back(currentUser);
|
||||||
|
auto currentUses = currentUser->getResult(0).getUses();
|
||||||
|
if (rangeLength(currentUses) != 1)
|
||||||
|
return std::nullopt;
|
||||||
|
currentValue = currentUser->getResult(0);
|
||||||
|
currentUser = currentUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<func::ReturnOp>(currentUser))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
return ConcatReturnUseInfo {
|
||||||
|
currentValue.getUses().begin()->getOperandNumber(),
|
||||||
|
std::move(sliceOffsets),
|
||||||
|
std::move(concatShape),
|
||||||
|
std::move(concatChain),
|
||||||
|
std::move(helperChain),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
|
||||||
|
ArrayRef<int64_t> sourceShape,
|
||||||
|
ArrayRef<Operation*> helperChain,
|
||||||
|
SmallVectorImpl<int64_t>& mappedIndices) {
|
||||||
|
SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
|
||||||
|
SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
|
||||||
|
|
||||||
|
auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
|
||||||
|
auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
|
||||||
|
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
currentIndices = expandFlatElementIndex(flatIndex, currentShape);
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (Operation* op : helperChain) {
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||||
|
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||||
|
};
|
||||||
|
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||||
|
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> nextIndices;
|
||||||
|
nextIndices.reserve(currentIndices.size());
|
||||||
|
for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
|
||||||
|
extractSliceOp.getStaticOffsets(),
|
||||||
|
extractSliceOp.getStaticSizes(),
|
||||||
|
extractSliceOp.getStaticStrides())) {
|
||||||
|
if (stride != 1 || index < offset || index >= offset + size)
|
||||||
|
return failure();
|
||||||
|
nextIndices.push_back(index - offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
currentIndices = std::move(nextIndices);
|
||||||
|
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
|
||||||
|
SmallVector<int64_t> nextIndices(currentIndices.size());
|
||||||
|
SmallVector<int64_t> nextShape(currentShape.size());
|
||||||
|
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
|
||||||
|
int64_t sourceIndex = attr.getInt();
|
||||||
|
nextIndices[destIndex] = currentIndices[sourceIndex];
|
||||||
|
nextShape[destIndex] = currentShape[sourceIndex];
|
||||||
|
}
|
||||||
|
currentIndices = std::move(nextIndices);
|
||||||
|
currentShape = std::move(nextShape);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
|
||||||
|
SmallVector<int64_t> nextIndices(currentIndices.size());
|
||||||
|
SmallVector<int64_t> nextShape(currentShape.size());
|
||||||
|
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
|
||||||
|
int64_t sourceIndex = attr.getInt();
|
||||||
|
nextIndices[destIndex] = currentIndices[sourceIndex];
|
||||||
|
nextShape[destIndex] = currentShape[sourceIndex];
|
||||||
|
}
|
||||||
|
currentIndices = std::move(nextIndices);
|
||||||
|
currentShape = std::move(nextShape);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
|
||||||
|
if (failed(reshapeToResultShape(op)))
|
||||||
|
return failure();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedIndices.assign(currentIndices.begin(), currentIndices.end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||||
|
for (Value operand : op->getOperands()) {
|
||||||
|
if (mapping.lookupOrNull(operand))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||||
|
IRMapping mapping;
|
||||||
|
mapping.map(sourceValue, sourceValue);
|
||||||
|
clonedValue = sourceValue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfterValue(sourceValue);
|
||||||
|
for (Operation* op : helperChain) {
|
||||||
|
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
clonedValue = clonedOp->getResult(0);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value emitHostCopy(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Value outputTensor,
|
||||||
|
Value sourceValue,
|
||||||
|
int32_t hostTargetOffset,
|
||||||
|
int32_t deviceSourceOffset,
|
||||||
|
int32_t sizeInBytes) {
|
||||||
|
return PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputTensor.getType(),
|
||||||
|
outputTensor,
|
||||||
|
sourceValue,
|
||||||
|
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||||
|
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||||
|
IRRewriter& rewriter,
|
||||||
|
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
||||||
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
|
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||||
|
Value currentReturnValue = returnValue;
|
||||||
|
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
|
||||||
|
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||||
|
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||||
|
outputTensors.push_back(
|
||||||
|
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto outRankedTensorType = llvm::dyn_cast<RankedTensorType>(currentReturnValue.getType());
|
||||||
|
auto memRefType = MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
|
||||||
|
|
||||||
|
std::string outputBaseName = ("output_" + Twine(index)).str();
|
||||||
|
std::string outputName = makeUniqueSymbolName(returnOp->getParentOfType<ModuleOp>(), outputBaseName);
|
||||||
|
rewriter.setInsertionPoint(returnOp.getParentOp());
|
||||||
|
memref::GlobalOp::create(rewriter,
|
||||||
|
returnOp.getLoc(),
|
||||||
|
rewriter.getStringAttr(outputName),
|
||||||
|
rewriter.getStringAttr("private"),
|
||||||
|
TypeAttr::get(memRefType),
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{});
|
||||||
|
outputTensors.push_back([memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) {
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
return toTensor.getResult();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||||
|
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||||
|
Location loc = computeOp->getLoc();
|
||||||
|
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||||
|
|
||||||
|
if (auto returnUse = analyzeReturnUse(result)) {
|
||||||
|
Value storedValue = yieldValue;
|
||||||
|
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
||||||
|
for (Operation* op : returnUse->helperChain)
|
||||||
|
markOpToRemove(state, op);
|
||||||
|
|
||||||
|
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||||
|
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||||
|
if (auto storedOp = storedValue.getDefiningOp())
|
||||||
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
|
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||||
|
emitHostCopy(
|
||||||
|
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
||||||
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultUses = result.getUses();
|
||||||
|
if (rangeLength(resultUses) == 1) {
|
||||||
|
OpOperand& resultUse = *resultUses.begin();
|
||||||
|
Operation* resultUser = resultUse.getOwner();
|
||||||
|
|
||||||
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
|
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
|
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||||
|
emitHostCopy(
|
||||||
|
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||||
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||||
|
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||||
|
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||||
|
markOpToRemove(state, concatOp);
|
||||||
|
|
||||||
|
if (concatReturnUse->helperChain.empty()) {
|
||||||
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
|
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
|
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||||
|
emitHostCopy(rewriter,
|
||||||
|
loc,
|
||||||
|
outputTensor,
|
||||||
|
yieldValue,
|
||||||
|
static_cast<int32_t>(flatOffset * elementSize),
|
||||||
|
0,
|
||||||
|
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||||
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||||
|
if (!storedType) {
|
||||||
|
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
|
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||||
|
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||||
|
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
|
||||||
|
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
||||||
|
|
||||||
|
SmallVector<int64_t> destinationIndices;
|
||||||
|
if (failed(mapIndicesThroughHelperChain(
|
||||||
|
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||||
|
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> extractOffsets;
|
||||||
|
SmallVector<OpFoldResult> extractSizes;
|
||||||
|
SmallVector<OpFoldResult> extractStrides;
|
||||||
|
extractOffsets.reserve(storedType.getRank());
|
||||||
|
extractSizes.reserve(storedType.getRank());
|
||||||
|
extractStrides.reserve(storedType.getRank());
|
||||||
|
for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
|
||||||
|
extractOffsets.push_back(rewriter.getIndexAttr(idx));
|
||||||
|
extractSizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scalarTensorType =
|
||||||
|
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||||
|
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||||
|
rewriter.setInsertionPointAfter(elementSlice);
|
||||||
|
|
||||||
|
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||||
|
outputTensor = emitHostCopy(rewriter,
|
||||||
|
loc,
|
||||||
|
outputTensor,
|
||||||
|
elementSlice.getResult(),
|
||||||
|
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||||
|
0,
|
||||||
|
static_cast<int32_t>(elementSize));
|
||||||
|
}
|
||||||
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReturnPathLoweringResult::NotReturnPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||||
|
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||||
|
if (!op)
|
||||||
|
return;
|
||||||
|
|
||||||
|
bool isExclusivelyOwnedByReturnChain = op->use_empty();
|
||||||
|
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||||
|
Operation* onlyUser = *op->getUsers().begin();
|
||||||
|
isExclusivelyOwnedByReturnChain =
|
||||||
|
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
||||||
|
|| isReturnHelperChainOp(onlyUser);
|
||||||
|
}
|
||||||
|
if (!isExclusivelyOwnedByReturnChain)
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (isReturnHelperChainOp(op)) {
|
||||||
|
Value source = op->getOperand(0);
|
||||||
|
markOpToRemove(state, op);
|
||||||
|
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
|
markOpToRemove(state, computeOp);
|
||||||
|
if (!computeOp.getInputs().empty())
|
||||||
|
for (Value input : computeOp.getInputs())
|
||||||
|
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
|
markOpToRemove(state, concatOp);
|
||||||
|
for (Value operand : concatOp.getOperands())
|
||||||
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||||
|
markOpToRemove(state, concatOp);
|
||||||
|
for (Value operand : concatOp.getInputs())
|
||||||
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||||
|
markOpToRemove(state, concatOp);
|
||||||
|
for (Value operand : concatOp.getInputs())
|
||||||
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
|
auto loc = returnOp.getLoc();
|
||||||
|
for (auto it : llvm::enumerate(originalOperands)) {
|
||||||
|
size_t orderWithinReturn = it.index();
|
||||||
|
Operation* returnOperand = it.value().getDefiningOp();
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
||||||
|
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||||
|
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||||
|
|
||||||
|
struct ReturnPathState {
|
||||||
|
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||||
|
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class ReturnPathLoweringResult {
|
||||||
|
Handled,
|
||||||
|
NotReturnPath,
|
||||||
|
Failure
|
||||||
|
};
|
||||||
|
|
||||||
|
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
||||||
|
mlir::IRRewriter& rewriter,
|
||||||
|
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
||||||
|
|
||||||
|
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||||
|
mlir::OpResult result,
|
||||||
|
mlir::Value yieldValue,
|
||||||
|
ReturnPathState& state,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -9,17 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
|
||||||
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
|
||||||
"spatial channel has precomputed source core id">;
|
|
||||||
|
|
||||||
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
|
||||||
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
|
||||||
"spatial channel has precomputed target core id">;
|
|
||||||
|
|
||||||
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
|
||||||
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
|
||||||
|
|
||||||
def onnxToPimTranspose : Pat<
|
def onnxToPimTranspose : Pat<
|
||||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||||
(PimTransposeOp $data, $perms,
|
(PimTransposeOp $data, $perms,
|
||||||
@@ -27,17 +16,11 @@ def onnxToPimTranspose : Pat<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVMM : Pat<
|
def spatToPimVMM : Pat<
|
||||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
|
||||||
(PimVMMOp $weightIndex, $vector,
|
(PimVMMOp $weightIndex, $vector,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimMVM : Pat<
|
|
||||||
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
|
||||||
(PimMVMOp $weightIndex, $vector,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
|
||||||
>;
|
|
||||||
|
|
||||||
def spatToPimVVAdd : Pat<
|
def spatToPimVVAdd : Pat<
|
||||||
(SpatVAddOp:$srcOpRes $a, $b),
|
(SpatVAddOp:$srcOpRes $a, $b),
|
||||||
(PimVVAddOp $a, $b,
|
(PimVVAddOp $a, $b,
|
||||||
@@ -80,18 +63,4 @@ def spatToPimVSoftmax : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatChannelSendToPimSend : Pat<
|
|
||||||
(SpatChannelSendOp $channel, $input),
|
|
||||||
(PimSendOp $input,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
|
||||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
|
||||||
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
|
||||||
>;
|
|
||||||
|
|
||||||
def spatChannelReceiveToPimReceive : Pat<
|
|
||||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
|
||||||
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
|
||||||
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
|
||||||
>;
|
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
#endif // SPATIAL_TO_PIM
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user