Compare commits
86 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 | |||
| a50e77ff38 | |||
| f56c4159b5 | |||
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| a103ba328b | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 | |||
| 8d95c604a6 | |||
| 55eda487dc | |||
| 061139aefb | |||
| ea61540e08 | |||
| 324178cba8 | |||
| e71ba07cd5 | |||
| 64a3805619 | |||
| 9f9e7c0892 | |||
| 03eab42971 | |||
| c15aba5d96 | |||
| 4821e8a55e | |||
| 88bb223bb1 | |||
| 623ee62a04 | |||
| ad56888b0b | |||
| f993840641 | |||
| 0c7db55a24 | |||
| 41de3cb150 | |||
| 4f3570520c | |||
| 628dc630a4 | |||
| 80a7298552 | |||
| 8ad504fcdf | |||
| e6f442c5d2 | |||
| f6b97b3813 | |||
| 26317ea7d0 | |||
| 909c4acfdd | |||
| feaff820e1 | |||
| 1e279ae9bb | |||
| 57f0cca8c0 | |||
| 5ff364027b | |||
| b1272d2283 | |||
| 58e6587697 | |||
| f6c8cc4aa5 | |||
| 566630b99a | |||
| 74931ad75b | |||
| f2fe147961 | |||
| 7bb58e80de | |||
| b2dc9c38b6 | |||
| 3cb6a1abc5 | |||
| 285773fa55 | |||
| bdacb9871d | |||
| 5b9bb0c191 | |||
| f789954ad7 | |||
| b6ba1e4fea | |||
| 717ad160cd | |||
| 905fa9f9a7 | |||
| 62b0a6e19d | |||
| b605585b1f | |||
| 08b0fcd850 | |||
| 9dccc2c701 | |||
| 5c839e62c1 | |||
| 15e8edb9c4 | |||
| 951baca106 | |||
| fc5bccb487 | |||
| 49dea15b95 | |||
| 5545b0f672 | |||
| cff929a083 | |||
| 89b3501aa8 | |||
| 412ca957f6 | |||
| 0f13269040 | |||
| dafc1d15b7 | |||
| 3fa140be25 | |||
| df703f0be9 | |||
| 9fa850c140 | |||
| 25ade1bd63 |
+12
@@ -1,5 +1,17 @@
|
|||||||
|
.zed
|
||||||
.idea
|
.idea
|
||||||
**/.vscode
|
**/.vscode
|
||||||
|
|
||||||
.claude
|
.claude
|
||||||
|
.codex
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
|
|
||||||
|
CMakeUserPresets.json
|
||||||
|
|
||||||
build
|
build
|
||||||
|
build_release
|
||||||
|
cmake-build-debug
|
||||||
|
cmake-build-release
|
||||||
|
compile.sh
|
||||||
|
|
||||||
|
**/__*
|
||||||
|
|||||||
+1
-1
@@ -3,4 +3,4 @@
|
|||||||
url = https://github.com/onnx/onnx-mlir.git
|
url = https://github.com/onnx/onnx-mlir.git
|
||||||
[submodule "backend-simulators/pim/pimsim-nn"]
|
[submodule "backend-simulators/pim/pimsim-nn"]
|
||||||
path = backend-simulators/pim/pimsim-nn
|
path = backend-simulators/pim/pimsim-nn
|
||||||
url = https://github.com/wangxy-2000/pimsim-nn.git
|
url = https://github.com/HEAPLab/pimsim-nn.git
|
||||||
|
|||||||
+80
-12
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
|||||||
|
|
||||||
project(raptor)
|
project(raptor)
|
||||||
|
|
||||||
# Add symlink to PIM as accelerator in onnx-mlir
|
# Materialize a CMake shim directory
|
||||||
function(raptor_ensure_symlink link_path target_path)
|
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||||
|
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||||
|
|
||||||
if(NOT EXISTS "${link_parent}")
|
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
message(FATAL_ERROR
|
||||||
|
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||||
|
" ${real_external_source_dir}"
|
||||||
|
)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (IS_SYMLINK "${shim_dir}")
|
||||||
|
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||||
|
file(REMOVE "${shim_dir}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||||
|
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
file(MAKE_DIRECTORY "${shim_dir}")
|
||||||
|
|
||||||
|
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||||
|
set(shim_contents
|
||||||
|
"get_filename_component(raptor_external_source_dir
|
||||||
|
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||||
|
REALPATH
|
||||||
|
)
|
||||||
|
add_subdirectory(
|
||||||
|
\"\${raptor_external_source_dir}\"
|
||||||
|
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||||
|
)
|
||||||
|
if (DEFINED PIM_ENABLED)
|
||||||
|
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||||
|
endif ()
|
||||||
|
"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (EXISTS "${shim_file}")
|
||||||
|
file(READ "${shim_file}" old_contents)
|
||||||
|
else ()
|
||||||
|
set(old_contents "")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (NOT old_contents STREQUAL shim_contents)
|
||||||
|
file(WRITE "${shim_file}" "${shim_contents}")
|
||||||
|
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||||
|
else ()
|
||||||
|
message(STATUS "CMake shim already up to date for ${description}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
# Mirror the external tree's first-level entries into the shim directory
|
||||||
|
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||||
|
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||||
|
|
||||||
|
foreach (child IN LISTS children)
|
||||||
|
if (child STREQUAL "CMakeLists.txt")
|
||||||
|
continue()
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
set(real_child "${real_external_source_dir}/${child}")
|
||||||
|
set(shim_child "${shim_dir}/${child}")
|
||||||
|
|
||||||
|
if (IS_SYMLINK "${shim_child}")
|
||||||
|
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||||
|
if (existing_link_target STREQUAL real_child)
|
||||||
|
continue()
|
||||||
|
endif ()
|
||||||
|
file(REMOVE_RECURSE "${shim_child}")
|
||||||
|
elseif (EXISTS "${shim_child}")
|
||||||
|
# Do not delete real files/directories. This protects the generated shim.
|
||||||
|
continue()
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
if(NOT EXISTS "${link_path}")
|
|
||||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
|
||||||
file(CREATE_LINK
|
file(CREATE_LINK
|
||||||
"${target_path}"
|
"${real_child}"
|
||||||
"${link_path}"
|
"${shim_child}"
|
||||||
SYMBOLIC
|
SYMBOLIC
|
||||||
)
|
)
|
||||||
endif()
|
endforeach ()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
raptor_ensure_symlink(
|
raptor_write_external_cmake_shim(
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||||
|
"PIM accelerator"
|
||||||
)
|
)
|
||||||
raptor_ensure_symlink(
|
|
||||||
|
raptor_write_external_cmake_shim(
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||||
|
"PIM accelerator tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patch onnx-mlir sources for PIM accelerator support.
|
# Patch onnx-mlir sources for PIM accelerator support.
|
||||||
|
|||||||
@@ -1,5 +1,206 @@
|
|||||||
# Raptor
|
# Raptor
|
||||||
|
|
||||||
|
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
|
||||||
|
targeting in-memory computing / processing-in-memory (PIM) architectures.
|
||||||
|
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
|
||||||
|
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
PIM architectures perform most of the computation directly in memory.
|
||||||
|
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
|
||||||
|
- a shared host memory,
|
||||||
|
- a number of cores that do most of the computation directly in their memory
|
||||||
|
(vector ops, vmm/mvm on ReRAM crossbars),
|
||||||
|
- no branching instructions (branchless architecture) and no hardware loop
|
||||||
|
support — any repeated work (e.g. convolutions) must be unrolled into
|
||||||
|
explicit per-iteration instructions.
|
||||||
|
|
||||||
|
Because of this, the amount of emitted instructions explodes quickly and the
|
||||||
|
compiler must optimize aggressively at every stage to keep compilation
|
||||||
|
tractable.
|
||||||
|
|
||||||
|
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
|
||||||
|
each carrying its own in-memory computing unit and crossbars. It will live in
|
||||||
|
a dedicated dialect (future work).
|
||||||
|
|
||||||
|
### Targets and simulators
|
||||||
|
|
||||||
|
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
|
||||||
|
**performance** estimates (latency, energy), but does not functionally execute
|
||||||
|
the JSON code it consumes. To validate the numerical correctness of the JSON
|
||||||
|
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
|
||||||
|
a Rust simulator we maintain in-tree at
|
||||||
|
`backend-simulators/pim/pim-simulator`.
|
||||||
|
|
||||||
|
## Compilation pipeline
|
||||||
|
|
||||||
|
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
|
||||||
|
When working on this codebase, most changes should stay confined to those
|
||||||
|
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
|
||||||
|
framework-level details).
|
||||||
|
|
||||||
|
High-level lowering flow:
|
||||||
|
|
||||||
|
```
|
||||||
|
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
|
||||||
|
```
|
||||||
|
|
||||||
|
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
|
||||||
|
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
|
||||||
|
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
|
||||||
|
operations are accelerated by storing a constant RHS matrix into a
|
||||||
|
crossbar. Crossbars cannot be re-programmed during execution, have a
|
||||||
|
limited fixed size, and there is a limited number of them per core.
|
||||||
|
Conversion patterns are split by op family under
|
||||||
|
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
|
||||||
|
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
|
||||||
|
Reshape, Resize, Split).
|
||||||
|
|
||||||
|
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
|
||||||
|
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
|
||||||
|
materializes PIM cores (`pim.core`), inter-core communication
|
||||||
|
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
|
||||||
|
|
||||||
|
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
|
||||||
|
A DCP-inspired heuristic (Dynamic Critical Path — see the original
|
||||||
|
scheduling paper by Kwok & Ahmad,
|
||||||
|
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
|
||||||
|
that coarsens the virtual node graph and decides how to group compute
|
||||||
|
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
|
||||||
|
heuristic with different assumptions from the paper (different cost
|
||||||
|
model, constraints from crossbar capacity / core resources, and a
|
||||||
|
windowed coarsening loop instead of full-graph reprioritization). The
|
||||||
|
`dcp-critical-window-size` option controls how many lowest-slack virtual
|
||||||
|
nodes each coarsening iteration considers (0 = legacy full-graph
|
||||||
|
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
|
||||||
|
`MergeComputeNodesPass.cpp`.
|
||||||
|
|
||||||
|
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
|
||||||
|
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
|
||||||
|
standard MLIR `BufferizableOpInterface` machinery
|
||||||
|
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
|
||||||
|
|
||||||
|
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
|
||||||
|
Conservatively reuses same-typed local memref allocations inside PIM cores
|
||||||
|
after bufferization and before code generation.
|
||||||
|
|
||||||
|
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
|
||||||
|
- `HostConstantFolding` — folds host-side constants.
|
||||||
|
- `MaterializeHostConstantsPass` — materializes the remaining host
|
||||||
|
constants for emission.
|
||||||
|
- `VerificationPass` — checks invariants before emission.
|
||||||
|
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
|
||||||
|
and `pim-simulator`.
|
||||||
|
|
||||||
|
Supporting pieces:
|
||||||
|
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
|
||||||
|
core count, DCP window, experimental conv impl, concat error handling, …)
|
||||||
|
and `PimCodeGen` entry points.
|
||||||
|
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
|
||||||
|
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
|
||||||
|
and the `PIMPasses.h` registry used by `PimAccelerator`.
|
||||||
|
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
|
||||||
|
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
|
||||||
|
|
||||||
|
## Key compiler options
|
||||||
|
|
||||||
|
Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||||
|
|
||||||
|
- `--maccel=PIM` — select the PIM accelerator.
|
||||||
|
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
|
||||||
|
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
|
||||||
|
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
|
||||||
|
run only the codegen tail.
|
||||||
|
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||||
|
per-core count.
|
||||||
|
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||||
|
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||||
|
merge-compute-nodes pass (default: `peft`).
|
||||||
|
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||||
|
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||||
|
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||||
|
|
||||||
|
## Validation
|
||||||
|
|
||||||
|
Functional validation lives in `validation/` and drives the Rust
|
||||||
|
`pim-simulator` to compare Raptor's output against a reference.
|
||||||
|
|
||||||
|
Per-operation validation (from `validation/`):
|
||||||
|
|
||||||
|
```
|
||||||
|
validate.py \
|
||||||
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
|
--onnx-include-dir ../onnx-mlir/include \
|
||||||
|
--core-count 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||||
|
|
||||||
|
```
|
||||||
|
validate.py \
|
||||||
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
|
--onnx-include-dir ../onnx-mlir/include \
|
||||||
|
--operations-dir ./networks/yolo11n/depth_04 \
|
||||||
|
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
Each validation run writes debugging artifacts into the benchmark's workspace
|
||||||
|
directory (for example `validation/operations/gemm/small/`):
|
||||||
|
- `inputs/` — generated input CSVs used for the run.
|
||||||
|
- `outputs/` — reference outputs dumped by the native ONNX runner.
|
||||||
|
- `raptor/` — compiler artifacts:
|
||||||
|
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
|
||||||
|
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
|
||||||
|
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
|
||||||
|
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
|
||||||
|
`raptor/reports/` such as `dcp_merge_report.txt`,
|
||||||
|
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
|
||||||
|
- `runner/` — generated reference runner source, build tree, and shared library.
|
||||||
|
- `simulation/out.bin` — raw simulator output dump used for output comparison.
|
||||||
|
|
||||||
|
That means you usually do not need to rerun standalone `--EmitSpatial` or
|
||||||
|
`--EmitPim` commands while debugging validation failures: the per-pass dialect
|
||||||
|
dumps are already available under `raptor/dialects/`.
|
||||||
|
|
||||||
|
The validator does not currently expose a simulator tracing flag, but once a
|
||||||
|
validation has produced `raptor/pim/` you can rerun the simulator manually with
|
||||||
|
tracing enabled:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend-simulators/pim/pim-simulator
|
||||||
|
cargo run --no-default-features --features tracing --release \
|
||||||
|
--package pim-simulator --bin pim-simulator -- \
|
||||||
|
-f /path/to/workspace/raptor/pim \
|
||||||
|
-o /path/to/workspace/simulation/out.bin \
|
||||||
|
-d <addr0>,<size0>,<addr1>,<size1>,...
|
||||||
|
```
|
||||||
|
|
||||||
|
With `--features tracing`, the simulator writes per-core traces as
|
||||||
|
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
|
||||||
|
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
|
||||||
|
and the model output shapes. If you need a clean slate before rerunning, use:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
validate.py --clean
|
||||||
|
```
|
||||||
|
|
||||||
|
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||||
|
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
||||||
|
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
||||||
|
`sigmoid`, `softmax`, `split`.
|
||||||
|
|
||||||
|
## Rebuilding
|
||||||
|
|
||||||
|
Release build (fast):
|
||||||
|
|
||||||
|
```
|
||||||
|
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
|
||||||
|
```
|
||||||
|
|
||||||
|
A slower debug build is also available — configure it the same way but with
|
||||||
|
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Protobuf
|
### Protobuf
|
||||||
|
|||||||
+2121
-8
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "pim-simulator"
|
name = "pim-simulator"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -13,8 +12,9 @@ name = "pimcore"
|
|||||||
path = "src/lib/pimcore.rs"
|
path = "src/lib/pimcore.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["tracing"]
|
default = []
|
||||||
tracing = []
|
tracing = []
|
||||||
|
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -27,3 +27,10 @@ hex = "0"
|
|||||||
paste = "1"
|
paste = "1"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
statrs = {version="0.16", optional=true}
|
||||||
|
comfy-table = {version="7.1", optional=true}
|
||||||
|
plotly = {version="0.8", optional=true}
|
||||||
|
rayon = "1.12.0"
|
||||||
|
faer = "0.24.0"
|
||||||
|
faer-traits = "0.24.0"
|
||||||
|
mimalloc = "0.1.50"
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
|
use mimalloc::MiMalloc;
|
||||||
|
|
||||||
|
#[global_allocator]
|
||||||
|
static GLOBAL: MiMalloc = MiMalloc;
|
||||||
|
|
||||||
use anyhow::{Context, Result, bail};
|
use anyhow::{Context, Result, bail};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use glob::glob;
|
use glob::glob;
|
||||||
|
use pimcore::binary_to_instruction::binary_to_executor;
|
||||||
use pimcore::cpu::crossbar::Crossbar;
|
use pimcore::cpu::crossbar::Crossbar;
|
||||||
use pimcore::json_to_instruction::json_to_executor;
|
use pimcore::json_to_instruction::json_to_executor;
|
||||||
use pimcore::memory_manager::CoreMemory;
|
use pimcore::memory_manager::CoreMemory;
|
||||||
use pimcore::tracing::TRACER;
|
use pimcore::tracing::TRACER;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs::{self, read_link};
|
use std::fs::{self, File, read_link};
|
||||||
use std::io::Write;
|
use std::io::{BufReader, Write};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
/// Program to simulate core execution configuration
|
/// Program to simulate core execution configuration
|
||||||
@@ -44,18 +50,24 @@ fn main() -> Result<()> {
|
|||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let config_json = retrive_config(&args)?;
|
let config_json = retrive_config(&args)?;
|
||||||
let core_jsons = retrive_cores(&args)?;
|
let mut core_inputs = retrive_cores(&args)?;
|
||||||
let memory = retrive_memory(&args)?;
|
let memory = retrive_memory(&args)?;
|
||||||
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
|
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
|
||||||
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
|
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
|
||||||
let mut executor =
|
let mut executor = match &mut core_inputs {
|
||||||
json_to_executor::json_to_executor(config_json, core_jsons.iter(), crossbars);
|
CoreInputs::Json(core_jsons) => {
|
||||||
|
json_to_executor::json_to_executor(config_json, core_jsons, crossbars)
|
||||||
|
}
|
||||||
|
CoreInputs::Binary(core_bins) => {
|
||||||
|
binary_to_executor(config_json, core_bins.iter(), crossbars)?
|
||||||
|
}
|
||||||
|
};
|
||||||
set_memory(&mut executor, memory);
|
set_memory(&mut executor, memory);
|
||||||
TRACER
|
TRACER
|
||||||
.lock()
|
.lock()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.init(executor.cpu().num_core(), args.output.clone());
|
.init(executor.cpu().num_core(), args.output.clone());
|
||||||
executor.execute();
|
executor.execute()?;
|
||||||
dump_memory(executor, &args)?;
|
dump_memory(executor, &args)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -65,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
|
|||||||
args: &Args,
|
args: &Args,
|
||||||
global_crossbars: &'c HashMap<String, Crossbar>,
|
global_crossbars: &'c HashMap<String, Crossbar>,
|
||||||
) -> Vec<Vec<&'c Crossbar>> {
|
) -> Vec<Vec<&'c Crossbar>> {
|
||||||
let mut res = Vec::new();
|
let mut res = vec![Vec::new()];
|
||||||
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||||
|
|
||||||
if let Some(folder) = args.folder.as_ref() {
|
if let Some(folder) = args.folder.as_ref() {
|
||||||
@@ -140,8 +152,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
|
|||||||
}
|
}
|
||||||
|
|
||||||
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
|
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
|
||||||
let mut crossbar =
|
let mut crossbar = Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||||
Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
|
||||||
crossbar.execute_store(&bytes).unwrap();
|
crossbar.execute_store(&bytes).unwrap();
|
||||||
res.insert(
|
res.insert(
|
||||||
weight_file
|
weight_file
|
||||||
@@ -214,45 +225,82 @@ fn retrive_memory(args: &Args) -> Result<Vec<u8>> {
|
|||||||
Ok(memory_vector)
|
Ok(memory_vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
|
enum CoreInputs {
|
||||||
let mut core_jsons: Vec<Value> = Vec::new();
|
Json(Vec<BufReader<File>>),
|
||||||
if let Some(cores_override) = &args.cores {
|
Binary(Vec<Vec<u8>>),
|
||||||
for core in cores_override {
|
|
||||||
let content = fs::read_to_string(core)
|
|
||||||
.with_context(|| format!("Failed to read core file: {:?}", cores_override))?;
|
|
||||||
let json: Value =
|
|
||||||
serde_json::from_str(&content).context("Failed to parse core json override")?;
|
|
||||||
core_jsons.push(json);
|
|
||||||
}
|
}
|
||||||
} else if let Some(folder) = args.folder.as_ref() {
|
|
||||||
let pattern = folder.join("core*.json");
|
fn retrive_cores(args: &Args) -> Result<CoreInputs, anyhow::Error> {
|
||||||
let pattern_str = pattern.to_str().context("Invalid path encoding")?;
|
if let Some(cores_override) = &args.cores {
|
||||||
let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
|
let first_extension = cores_override
|
||||||
paths.sort_by_cached_key(|x| {
|
.first()
|
||||||
let mut x = x
|
.and_then(|path| path.extension())
|
||||||
|
.and_then(|ext| ext.to_str())
|
||||||
|
.unwrap_or_default();
|
||||||
|
if first_extension == "pim" {
|
||||||
|
let mut core_bins = Vec::with_capacity(cores_override.len());
|
||||||
|
for core in cores_override {
|
||||||
|
core_bins.push(
|
||||||
|
fs::read(core)
|
||||||
|
.with_context(|| format!("Failed to read binary core file: {:?}", core))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return Ok(CoreInputs::Binary(core_bins));
|
||||||
|
}
|
||||||
|
let mut core_jsons_reader: Vec<BufReader<File>> = Vec::with_capacity(cores_override.len());
|
||||||
|
for core in cores_override {
|
||||||
|
let file = File::open(core)?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
core_jsons_reader.push(reader);
|
||||||
|
}
|
||||||
|
return Ok(CoreInputs::Json(core_jsons_reader));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(folder) = args.folder.as_ref() {
|
||||||
|
let binary_pattern = folder.join("core*.pim");
|
||||||
|
let binary_pattern_str = binary_pattern.to_str().context("Invalid path encoding")?;
|
||||||
|
let mut binary_paths: Vec<_> = glob(binary_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||||
|
binary_paths.sort_by_cached_key(core_sort_key);
|
||||||
|
if !binary_paths.is_empty() {
|
||||||
|
let mut core_bins = Vec::with_capacity(binary_paths.len());
|
||||||
|
for path in binary_paths {
|
||||||
|
core_bins.push(
|
||||||
|
fs::read(&path)
|
||||||
|
.with_context(|| format!("Failed to read core file: {:?}", path))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return Ok(CoreInputs::Binary(core_bins));
|
||||||
|
}
|
||||||
|
|
||||||
|
let json_pattern = folder.join("core*.json");
|
||||||
|
let json_pattern_str = json_pattern.to_str().context("Invalid path encoding")?;
|
||||||
|
let mut json_paths: Vec<_> = glob(json_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||||
|
json_paths.sort_by_cached_key(core_sort_key);
|
||||||
|
|
||||||
|
if json_paths.is_empty() {
|
||||||
|
bail!("No core*.pim or core*.json files found in {:?}", folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut core_json_reader: Vec<BufReader<File>> = Vec::with_capacity(json_paths.len());
|
||||||
|
for path in json_paths {
|
||||||
|
let file = File::open(path)?;
|
||||||
|
let reader = BufReader::new(file);
|
||||||
|
core_json_reader.push(reader);
|
||||||
|
}
|
||||||
|
return Ok(CoreInputs::Json(core_json_reader));
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||||
|
}
|
||||||
|
|
||||||
|
fn core_sort_key(path: &PathBuf) -> i32 {
|
||||||
|
let mut stem = path
|
||||||
.file_stem()
|
.file_stem()
|
||||||
.expect("Extracting the stem")
|
.expect("Extracting the stem")
|
||||||
.to_str()
|
.to_str()
|
||||||
.expect("File not utf-8");
|
.expect("File not utf-8");
|
||||||
x = &x[5..];
|
stem = &stem[5..];
|
||||||
x.parse::<i32>().unwrap()
|
stem.parse::<i32>().unwrap()
|
||||||
});
|
|
||||||
|
|
||||||
if paths.is_empty() {
|
|
||||||
bail!("No core*.json files found in {:?}", folder);
|
|
||||||
}
|
|
||||||
for entry in paths {
|
|
||||||
let path = entry;
|
|
||||||
let content = fs::read_to_string(&path)
|
|
||||||
.with_context(|| format!("Failed to read core file: {:?}", path))?;
|
|
||||||
let json: Value = serde_json::from_str(&content)
|
|
||||||
.with_context(|| format!("Failed to parse JSON in {:?}", path))?;
|
|
||||||
core_jsons.push(json);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
|
||||||
}
|
|
||||||
Ok(core_jsons)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
|
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
|
||||||
|
|||||||
@@ -0,0 +1,497 @@
|
|||||||
|
use crate::{
|
||||||
|
CoreInstructionsBuilder, Executable,
|
||||||
|
cpu::{CPU, crossbar::Crossbar},
|
||||||
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
|
};
|
||||||
|
use anyhow::{Context, Result, bail, ensure};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
|
const MAGIC: &[u8; 4] = b"PIMB";
|
||||||
|
const VERSION: u32 = 1;
|
||||||
|
const HEADER_SIZE: usize = 12;
|
||||||
|
const RECORD_SIZE: usize = 20;
|
||||||
|
|
||||||
|
macro_rules! add_name {
|
||||||
|
($storage:ident, $opcode:literal, $name:literal) => {
|
||||||
|
$storage.insert($opcode, $name);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static INSTRUCTIONS: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||||
|
let mut hash = HashMap::new();
|
||||||
|
add_name!(hash, 0, "nop");
|
||||||
|
add_name!(hash, 1, "sldi");
|
||||||
|
add_name!(hash, 2, "sld");
|
||||||
|
add_name!(hash, 3, "sadd");
|
||||||
|
add_name!(hash, 4, "ssub");
|
||||||
|
add_name!(hash, 5, "smul");
|
||||||
|
add_name!(hash, 6, "saddi");
|
||||||
|
add_name!(hash, 7, "smuli");
|
||||||
|
add_name!(hash, 8, "setbw");
|
||||||
|
add_name!(hash, 9, "mvmul");
|
||||||
|
add_name!(hash, 10, "vvadd");
|
||||||
|
add_name!(hash, 11, "vvsub");
|
||||||
|
add_name!(hash, 12, "vvmul");
|
||||||
|
add_name!(hash, 13, "vvdmul");
|
||||||
|
add_name!(hash, 14, "vvmax");
|
||||||
|
add_name!(hash, 15, "vvsll");
|
||||||
|
add_name!(hash, 16, "vvsra");
|
||||||
|
add_name!(hash, 17, "vavg");
|
||||||
|
add_name!(hash, 18, "vrelu");
|
||||||
|
add_name!(hash, 19, "vtanh");
|
||||||
|
add_name!(hash, 20, "vsigm");
|
||||||
|
add_name!(hash, 21, "vsoftmax");
|
||||||
|
add_name!(hash, 22, "vmv");
|
||||||
|
add_name!(hash, 23, "vrsu");
|
||||||
|
add_name!(hash, 24, "vrsl");
|
||||||
|
add_name!(hash, 25, "ld");
|
||||||
|
add_name!(hash, 26, "st");
|
||||||
|
add_name!(hash, 27, "lldi");
|
||||||
|
add_name!(hash, 28, "lmv");
|
||||||
|
add_name!(hash, 29, "send");
|
||||||
|
add_name!(hash, 30, "recv");
|
||||||
|
add_name!(hash, 31, "wait");
|
||||||
|
add_name!(hash, 32, "sync");
|
||||||
|
hash
|
||||||
|
});
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Default)]
|
||||||
|
struct InstructionRecord {
|
||||||
|
opcode: u8,
|
||||||
|
rd: u8,
|
||||||
|
r1: u8,
|
||||||
|
r2_or_imm: i32,
|
||||||
|
generic1: i32,
|
||||||
|
generic2: i32,
|
||||||
|
generic3: i32,
|
||||||
|
flags: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
|
||||||
|
u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_i32_le(bytes: &[u8], offset: usize) -> i32 {
|
||||||
|
i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_binary_records(bytes: &[u8]) -> Result<Vec<InstructionRecord>> {
|
||||||
|
ensure!(bytes.len() >= HEADER_SIZE, "binary core file too small");
|
||||||
|
ensure!(&bytes[0..4] == MAGIC, "invalid PIM binary magic");
|
||||||
|
|
||||||
|
let version = read_u32_le(bytes, 4);
|
||||||
|
ensure!(
|
||||||
|
version == VERSION,
|
||||||
|
"unsupported PIM binary version {version}"
|
||||||
|
);
|
||||||
|
|
||||||
|
let instruction_count = read_u32_le(bytes, 8) as usize;
|
||||||
|
let expected_len = HEADER_SIZE + instruction_count * RECORD_SIZE;
|
||||||
|
ensure!(
|
||||||
|
bytes.len() == expected_len,
|
||||||
|
"PIM binary size mismatch: expected {expected_len} bytes, got {}",
|
||||||
|
bytes.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut records = Vec::with_capacity(instruction_count);
|
||||||
|
for index in 0..instruction_count {
|
||||||
|
let base = HEADER_SIZE + index * RECORD_SIZE;
|
||||||
|
records.push(InstructionRecord {
|
||||||
|
opcode: bytes[base],
|
||||||
|
rd: bytes[base + 1],
|
||||||
|
r1: bytes[base + 2],
|
||||||
|
flags: bytes[base + 3],
|
||||||
|
r2_or_imm: read_i32_le(bytes, base + 4),
|
||||||
|
generic1: read_i32_le(bytes, base + 8),
|
||||||
|
generic2: read_i32_le(bytes, base + 12),
|
||||||
|
generic3: read_i32_le(bytes, base + 16),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(records)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append_record(
|
||||||
|
inst_builder: &mut InstructionsBuilder,
|
||||||
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
|
record: InstructionRecord,
|
||||||
|
) -> Result<()> {
|
||||||
|
let InstructionRecord {
|
||||||
|
opcode,
|
||||||
|
rd,
|
||||||
|
r1,
|
||||||
|
r2_or_imm,
|
||||||
|
generic1,
|
||||||
|
generic2,
|
||||||
|
generic3,
|
||||||
|
flags: _,
|
||||||
|
} = record;
|
||||||
|
|
||||||
|
match opcode {
|
||||||
|
0 => {}
|
||||||
|
1 => {
|
||||||
|
inst_data_builder.set_rd_u8(rd).set_imm(r2_or_imm);
|
||||||
|
inst_builder.make_inst(sldi, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd_u8(rd)
|
||||||
|
.set_r1_u8(r1)
|
||||||
|
.set_offset_select(generic1)
|
||||||
|
.set_offset_value(generic2);
|
||||||
|
inst_builder.make_inst(sld, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
3 => {
|
||||||
|
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||||
|
inst_builder.make_inst(sadd, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
4 => {
|
||||||
|
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||||
|
inst_builder.make_inst(ssub, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
5 => {
|
||||||
|
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||||
|
inst_builder.make_inst(smul, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
6 => {
|
||||||
|
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||||
|
inst_builder.make_inst(saddi, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
7 => {
|
||||||
|
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||||
|
inst_builder.make_inst(smuli, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
8 => {
|
||||||
|
inst_data_builder.set_ibiw_obiw(generic1, generic2);
|
||||||
|
inst_builder.make_inst(setbw, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
9 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd_u8(rd)
|
||||||
|
.set_r1_u8(r1)
|
||||||
|
.set_mbiw_immrelu_immgroup(r2_or_imm, generic1, generic2);
|
||||||
|
inst_builder.make_inst(mvmul, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
10 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvadd, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
11 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
12 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvmul, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
13 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
14 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvmax, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
15 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvsll, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
16 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vvsra, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
17 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
18 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vrelu, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
19 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vtanh, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
20 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vsigm, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
21 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
22 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vmv, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
23 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vrsu, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
24 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(vrsl, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
25 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(ld, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
26 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(st, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
27 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd_u8(rd)
|
||||||
|
.set_imm(r2_or_imm)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(lldi, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
28 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rdr1_u8(rd, r1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(lmv, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
29 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd_u8(rd)
|
||||||
|
.set_imm_core(r2_or_imm + 1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(send, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
30 => {
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd_u8(rd)
|
||||||
|
.set_imm_core(r2_or_imm + 1)
|
||||||
|
.set_imm_len(generic3)
|
||||||
|
.set_offset_select_value(generic1, generic2);
|
||||||
|
inst_builder.make_inst(recv, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
31 => {
|
||||||
|
inst_builder.make_inst(wait, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
32 => {
|
||||||
|
inst_builder.make_inst(sync, inst_data_builder.build());
|
||||||
|
}
|
||||||
|
_ => bail!("unsupported PIM binary opcode {opcode}"),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_to_instructions(
|
||||||
|
core_bytes: &[u8],
|
||||||
|
core_index: i32,
|
||||||
|
) -> Result<Vec<crate::instruction_set::Instruction>> {
|
||||||
|
let records = parse_binary_records(core_bytes)?;
|
||||||
|
let mut insts_builder = InstructionsBuilder::new();
|
||||||
|
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||||
|
inst_data_builder
|
||||||
|
.set_core_indx_u16(u16::try_from(core_index).expect("core index does not fit in u16"))
|
||||||
|
.fix_core_indx();
|
||||||
|
|
||||||
|
for record in records {
|
||||||
|
let opcode = record.opcode;
|
||||||
|
let name = INSTRUCTIONS
|
||||||
|
.get(&(opcode as usize))
|
||||||
|
.copied()
|
||||||
|
.unwrap_or("<unknown>");
|
||||||
|
|
||||||
|
append_record(&mut insts_builder, &mut inst_data_builder, record).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"while decoding binary instruction for core {core_index}: opcode {opcode} ({name})"
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(insts_builder.build())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn binary_to_executor<'a, 'b>(
|
||||||
|
config: Value,
|
||||||
|
cores: impl Iterator<Item = &'b Vec<u8>>,
|
||||||
|
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||||
|
) -> Result<Executable<'a>> {
|
||||||
|
let core_cnt = config
|
||||||
|
.get("core_cnt")
|
||||||
|
.context("missing core_cnt in config")?
|
||||||
|
.as_i64()
|
||||||
|
.context("core_cnt is not an integer")? as i32;
|
||||||
|
|
||||||
|
let cpu = CPU::new(core_cnt, crossbars);
|
||||||
|
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||||
|
for (external_core_indx, core_bytes) in cores.enumerate() {
|
||||||
|
let core_indx = external_core_indx as i32 + 1;
|
||||||
|
let instructions = binary_to_instructions(core_bytes, core_indx)?;
|
||||||
|
core_insts_builder.set_core(core_indx, instructions);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Executable::new(cpu, core_insts_builder.build()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
functor_to_name,
|
||||||
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||||
|
json_to_instruction::json_isa::json_to_instruction,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn encode_record(record: InstructionRecord, dst: &mut Vec<u8>) {
|
||||||
|
dst.push(record.opcode);
|
||||||
|
dst.push(record.rd);
|
||||||
|
dst.push(record.r1);
|
||||||
|
dst.push(record.flags);
|
||||||
|
dst.extend_from_slice(&record.r2_or_imm.to_le_bytes());
|
||||||
|
dst.extend_from_slice(&record.generic1.to_le_bytes());
|
||||||
|
dst.extend_from_slice(&record.generic2.to_le_bytes());
|
||||||
|
dst.extend_from_slice(&record.generic3.to_le_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn binary_blob(records: &[InstructionRecord]) -> Vec<u8> {
|
||||||
|
let mut blob = Vec::with_capacity(HEADER_SIZE + records.len() * RECORD_SIZE);
|
||||||
|
blob.extend_from_slice(MAGIC);
|
||||||
|
blob.extend_from_slice(&VERSION.to_le_bytes());
|
||||||
|
blob.extend_from_slice(&(records.len() as u32).to_le_bytes());
|
||||||
|
for &record in records {
|
||||||
|
encode_record(record, &mut blob);
|
||||||
|
}
|
||||||
|
blob
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn json_and_binary_decoders_match_for_representative_ops() {
|
||||||
|
let json_program = [
|
||||||
|
r#"{"imm":64,"op":"sldi","rd":0}"#,
|
||||||
|
r#"{"imm":128,"op":"sldi","rd":1}"#,
|
||||||
|
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"lmv","rd":0,"rs1":1}"#,
|
||||||
|
r#"{"group":3,"mbiw":8,"op":"mvmul","rd":0,"relu":0,"rs1":1}"#,
|
||||||
|
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"vvadd","rd":0,"rs1":1,"rs2":2}"#,
|
||||||
|
r#"{"core":2,"offset":{"offset_select":0,"offset_value":0},"op":"send","rd":0,"size":16}"#,
|
||||||
|
];
|
||||||
|
|
||||||
|
let binary_program = binary_blob(&[
|
||||||
|
InstructionRecord {
|
||||||
|
opcode: 1,
|
||||||
|
rd: 0,
|
||||||
|
r2_or_imm: 64,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
InstructionRecord {
|
||||||
|
opcode: 1,
|
||||||
|
rd: 1,
|
||||||
|
r2_or_imm: 128,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
InstructionRecord {
|
||||||
|
opcode: 28,
|
||||||
|
rd: 0,
|
||||||
|
r1: 1,
|
||||||
|
generic3: 16,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
InstructionRecord {
|
||||||
|
opcode: 9,
|
||||||
|
rd: 0,
|
||||||
|
r1: 1,
|
||||||
|
r2_or_imm: 8,
|
||||||
|
generic2: 3,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
InstructionRecord {
|
||||||
|
opcode: 10,
|
||||||
|
rd: 0,
|
||||||
|
r1: 1,
|
||||||
|
r2_or_imm: 2,
|
||||||
|
generic3: 16,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
InstructionRecord {
|
||||||
|
opcode: 29,
|
||||||
|
rd: 0,
|
||||||
|
r2_or_imm: 2,
|
||||||
|
generic3: 16,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let mut json_builder = InstructionsBuilder::new();
|
||||||
|
let mut json_data_builder = InstructionDataBuilder::new();
|
||||||
|
json_data_builder.set_core_indx(1).fix_core_indx();
|
||||||
|
for inst in json_program {
|
||||||
|
let value = serde_json::from_str(inst).unwrap();
|
||||||
|
json_to_instruction(&mut json_builder, &mut json_data_builder, &value);
|
||||||
|
}
|
||||||
|
let json_instructions = json_builder.build();
|
||||||
|
let binary_instructions = binary_to_instructions(&binary_program, 1).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(json_instructions.len(), binary_instructions.len());
|
||||||
|
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
||||||
|
assert_eq!(
|
||||||
|
functor_to_name(json_inst.functor as usize),
|
||||||
|
functor_to_name(binary_inst.functor as usize)
|
||||||
|
);
|
||||||
|
assert_eq!(json_inst.data, binary_inst.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
use paste::paste;
|
use paste::paste;
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Default)]
|
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||||
pub struct InstructionData {
|
pub struct InstructionData {
|
||||||
core_indx: i32,
|
core_indx: u16,
|
||||||
rd: i32,
|
rd: u8,
|
||||||
r1: i32,
|
r1: u8,
|
||||||
//r2 imm mbiw imm_core
|
//r2 imm mbiw imm_core
|
||||||
r2_or_imm: i32,
|
r2_or_imm: i32,
|
||||||
//offset_select imm_relu ibiw
|
//offset_select imm_relu ibiw
|
||||||
@@ -16,18 +17,30 @@ pub struct InstructionData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl InstructionData {
|
impl InstructionData {
|
||||||
pub fn core_indx(&self) -> i32 {
|
pub fn core_indx_u16(&self) -> u16 {
|
||||||
self.core_indx
|
self.core_indx
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rd(&self) -> i32 {
|
pub fn core_indx(&self) -> i32 {
|
||||||
|
i32::from(self.core_indx)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rd_u8(&self) -> u8 {
|
||||||
self.rd
|
self.rd
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn r1(&self) -> i32 {
|
pub fn rd(&self) -> i32 {
|
||||||
|
i32::from(self.rd)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn r1_u8(&self) -> u8 {
|
||||||
self.r1
|
self.r1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn r1(&self) -> i32 {
|
||||||
|
i32::from(self.r1)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn r2(&self) -> i32 {
|
pub fn r2(&self) -> i32 {
|
||||||
self.r2_or_imm
|
self.r2_or_imm
|
||||||
}
|
}
|
||||||
@@ -49,26 +62,26 @@ impl InstructionData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
|
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
|
||||||
(self.core_indx, self.rd, self.r1)
|
(self.core_indx(), self.rd(), self.r1())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
|
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
|
||||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
|
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
|
||||||
(self.core_indx, self.rd, self.r2_or_imm)
|
(self.core_indx(), self.rd(), self.r2_or_imm)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
|
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
|
||||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
|
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
|
||||||
(
|
(
|
||||||
self.core_indx,
|
self.core_indx(),
|
||||||
self.rd,
|
self.rd(),
|
||||||
self.r1,
|
self.r1(),
|
||||||
self.r2_or_imm,
|
self.r2_or_imm,
|
||||||
self.generic3,
|
self.generic3,
|
||||||
self.generic1,
|
self.generic1,
|
||||||
@@ -78,9 +91,9 @@ impl InstructionData {
|
|||||||
|
|
||||||
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
|
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
|
||||||
(
|
(
|
||||||
self.core_indx,
|
self.core_indx(),
|
||||||
self.rd,
|
self.rd(),
|
||||||
self.r1,
|
self.r1(),
|
||||||
self.r2_or_imm,
|
self.r2_or_imm,
|
||||||
self.generic1,
|
self.generic1,
|
||||||
self.generic2,
|
self.generic2,
|
||||||
@@ -100,7 +113,7 @@ impl InstructionData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
|
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
|
||||||
(self.core_indx, self.r2_or_imm)
|
(self.core_indx(), self.r2_or_imm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,6 +229,18 @@ impl InstructionDataBuilder {
|
|||||||
common_getter_setter![imm_group];
|
common_getter_setter![imm_group];
|
||||||
common_getter_setter![imm_core];
|
common_getter_setter![imm_core];
|
||||||
|
|
||||||
|
pub fn set_core_indx_u16(&mut self, val: u16) -> &mut Self {
|
||||||
|
self.set_core_indx(i32::from(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_rd_u8(&mut self, val: u8) -> &mut Self {
|
||||||
|
self.set_rd(i32::from(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_r1_u8(&mut self, val: u8) -> &mut Self {
|
||||||
|
self.set_r1(i32::from(val))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
core_indx: Fixer::Edit(0),
|
core_indx: Fixer::Edit(0),
|
||||||
@@ -254,20 +279,16 @@ impl InstructionDataBuilder {
|
|||||||
|
|
||||||
fn check_sanity(&self) {
|
fn check_sanity(&self) {
|
||||||
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
|
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
|
||||||
assert!(
|
assert!(!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0));
|
||||||
!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0)
|
assert!(!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0));
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(&mut self) -> InstructionData {
|
pub fn build(&mut self) -> InstructionData {
|
||||||
self.check_sanity();
|
self.check_sanity();
|
||||||
let inst_data = InstructionData {
|
let inst_data = InstructionData {
|
||||||
core_indx: self.get_core_indx(),
|
core_indx: u16::try_from(self.get_core_indx()).expect("core index does not fit in u16"),
|
||||||
rd: self.get_rd(),
|
rd: u8::try_from(self.get_rd()).expect("rd does not fit in u8"),
|
||||||
r1: self.get_r1(),
|
r1: u8::try_from(self.get_r1()).expect("r1 does not fit in u8"),
|
||||||
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
|
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
|
||||||
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
|
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
|
||||||
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
|
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
|
||||||
@@ -281,6 +302,10 @@ impl InstructionDataBuilder {
|
|||||||
self.set_rd(rd).set_r1(r1).set_r2(r2)
|
self.set_rd(rd).set_r1(r1).set_r2(r2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_rdr1r2_u8(&mut self, rd: u8, r1: u8, r2: i32) -> &mut Self {
|
||||||
|
self.set_rd_u8(rd).set_r1_u8(r1).set_r2(r2)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
|
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
|
||||||
self.set_offset_select(offset_select)
|
self.set_offset_select(offset_select)
|
||||||
.set_offset_value(offset_value)
|
.set_offset_value(offset_value)
|
||||||
@@ -290,14 +315,26 @@ impl InstructionDataBuilder {
|
|||||||
self.set_rd(rd).set_r1(r1).set_imm(imm)
|
self.set_rd(rd).set_r1(r1).set_imm(imm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_rdr1imm_u8(&mut self, rd: u8, r1: u8, imm: i32) -> &mut Self {
|
||||||
|
self.set_rd_u8(rd).set_r1_u8(r1).set_imm(imm)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
|
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
|
||||||
self.set_rd(rd).set_r1(r1)
|
self.set_rd(rd).set_r1(r1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_rdr1_u8(&mut self, rd: u8, r1: u8) -> &mut Self {
|
||||||
|
self.set_rd_u8(rd).set_r1_u8(r1)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
|
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
|
||||||
self.set_rd(rd).set_imm(imm)
|
self.set_rd(rd).set_imm(imm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_rdimm_u8(&mut self, rd: u8, imm: i32) -> &mut Self {
|
||||||
|
self.set_rd_u8(rd).set_imm(imm)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
|
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
|
||||||
self.set_ibiw(ibiw).set_obiw(obiw)
|
self.set_ibiw(ibiw).set_obiw(obiw)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
cpu::{CPU, crossbar}, instruction_set::{
|
cpu::{CPU, crossbar},
|
||||||
|
instruction_set::{
|
||||||
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
|
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
|
||||||
helper::add_all,
|
helper::add_all,
|
||||||
}, memory_manager::{
|
},
|
||||||
|
memory_manager::{
|
||||||
MemoryStorable,
|
MemoryStorable,
|
||||||
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
||||||
}, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
|
},
|
||||||
|
tracing::TRACER,
|
||||||
|
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
|
||||||
};
|
};
|
||||||
use aligned_vec::{AVec, ConstAlign};
|
use aligned_vec::{AVec, ConstAlign};
|
||||||
use anyhow::{Context, Result, ensure};
|
use anyhow::{Context, Result, ensure};
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
use paste::paste;
|
use paste::paste;
|
||||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
|
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
|
||||||
@@ -30,7 +35,7 @@ macro_rules! add_name_simd {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
pub static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||||
let mut hash = HashMap::new();
|
let mut hash = HashMap::new();
|
||||||
add_name!(hash, sldi);
|
add_name!(hash, sldi);
|
||||||
add_name!(hash, sld);
|
add_name!(hash, sld);
|
||||||
@@ -76,8 +81,8 @@ pub fn functor_to_name(functor: usize) -> &'static str {
|
|||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
/////////////////Scalar/register Instructions//////////////////
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
#[inline(never)]
|
||||||
{
|
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_sldi(cores, data);
|
TRACER.lock().unwrap().pre_sldi(cores, data);
|
||||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||||
let core = cores.core(core_indx);
|
let core = cores.core(core_indx);
|
||||||
@@ -86,6 +91,7 @@ pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_sld(cores, data);
|
TRACER.lock().unwrap().pre_sld(cores, data);
|
||||||
let (core_indx, rd, r1) = data.get_core_rd_r1();
|
let (core_indx, rd, r1) = data.get_core_rd_r1();
|
||||||
@@ -100,6 +106,7 @@ pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_sadd(cores, data);
|
TRACER.lock().unwrap().pre_sadd(cores, data);
|
||||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||||
@@ -110,6 +117,7 @@ pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_ssub(cores, data);
|
TRACER.lock().unwrap().pre_ssub(cores, data);
|
||||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||||
@@ -120,6 +128,7 @@ pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_smul(cores, data);
|
TRACER.lock().unwrap().pre_smul(cores, data);
|
||||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||||
@@ -130,6 +139,7 @@ pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_saddi(cores, data);
|
TRACER.lock().unwrap().pre_saddi(cores, data);
|
||||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||||
@@ -139,6 +149,7 @@ pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_smuli(cores, data);
|
TRACER.lock().unwrap().pre_smuli(cores, data);
|
||||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||||
@@ -213,14 +224,17 @@ pub fn is_setbw(functor: InstructionType) -> bool {
|
|||||||
functor as usize == setbw as *const () as usize
|
functor as usize == setbw as *const () as usize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
|
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn mvm_impl_internal<F, M, T>(
|
pub(super) fn mvm_impl_internal<F, M, T>(
|
||||||
cores: &mut CPU,
|
cores: &mut CPU,
|
||||||
data: InstructionData,
|
data: InstructionData,
|
||||||
@@ -229,25 +243,30 @@ where
|
|||||||
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||||
[M]: UpcastSlice<T>,
|
[M]: UpcastSlice<T>,
|
||||||
T: UpcastDestTraits<T> + MemoryStorable,
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
// Add faer::ComplexField HERE, directly bounding M for this function only
|
||||||
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
|
||||||
F: UpcastDestTraits<F> + MemoryStorable,
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
{
|
{
|
||||||
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
|
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
|
||||||
|
|
||||||
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
||||||
let group: usize = group.try_into().context("group can not be negative")?;
|
let group: usize = group.try_into().context("group can not be negative")?;
|
||||||
|
|
||||||
let core = cores.core(core_indx);
|
let core = cores.core(core_indx);
|
||||||
let r1_val = core.register(r1);
|
let r1_val = core.register(r1);
|
||||||
let rd_val = core.register(rd);
|
let rd_val = core.register(rd);
|
||||||
|
|
||||||
let (memory, crossbars) = core.get_memory_crossbar();
|
let (memory, crossbars) = core.get_memory_crossbar();
|
||||||
let crossbar = crossbars.get_mut(group).unwrap();
|
let crossbar = crossbars.get_mut(group).unwrap();
|
||||||
let crossbar_stored_bytes = crossbar.stored_bytes();
|
let crossbar_stored_bytes = crossbar.stored_bytes();
|
||||||
let crossbar_byte_width = crossbar.width();
|
let crossbar_byte_width = crossbar.width();
|
||||||
//Fix this
|
|
||||||
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
||||||
ensure!(
|
ensure!(
|
||||||
crossbar_byte_width & size_of::<M>() == 0,
|
crossbar_byte_width % size_of::<M>() == 0,
|
||||||
"M not divisor of the crosbbar size"
|
"M not divisor of the crosbbar size"
|
||||||
);
|
);
|
||||||
|
|
||||||
let crossbar_height = crossbar.height();
|
let crossbar_height = crossbar.height();
|
||||||
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
||||||
|
|
||||||
@@ -257,19 +276,29 @@ where
|
|||||||
let load = loads[0];
|
let load = loads[0];
|
||||||
let vec: Cow<[M]> = load.up();
|
let vec: Cow<[M]> = load.up();
|
||||||
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
||||||
let mut res = Vec::with_capacity(crossbar_elem_width);
|
|
||||||
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
|
|
||||||
partial.resize(vec.len(), M::from_f32(0.0));
|
|
||||||
|
|
||||||
for x in 0..crossbar_elem_width {
|
// --- FAER IMPLEMENTATION ---
|
||||||
partial[0] = vec[0] * matrix[x];
|
|
||||||
for y in 1..crossbar_height {
|
// 1. Explicitly create a Matrix Reference (MatRef)
|
||||||
partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
|
let matrix_view = faer::mat::MatRef::from_row_major_slice(
|
||||||
}
|
matrix.as_ref(),
|
||||||
|
crossbar_height,
|
||||||
|
crossbar_elem_width,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Explicitly create a Column Vector Reference (ColRef)
|
||||||
|
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
|
||||||
|
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
|
||||||
|
|
||||||
|
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
|
||||||
|
|
||||||
|
// 4. Convert back to standard Rust Vec
|
||||||
|
// try_as_slice() returns an Option<&[M]>.
|
||||||
|
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
|
||||||
|
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
|
||||||
|
|
||||||
|
// --- END FAER ---
|
||||||
|
|
||||||
let mut acc = add_all(partial.as_slice());
|
|
||||||
res.push(acc);
|
|
||||||
}
|
|
||||||
if relu != 0 {
|
if relu != 0 {
|
||||||
res.iter_mut().for_each(|x| {
|
res.iter_mut().for_each(|x| {
|
||||||
if *x < M::from_f32(0.0) {
|
if *x < M::from_f32(0.0) {
|
||||||
@@ -277,16 +306,20 @@ where
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
ensure!(
|
ensure!(
|
||||||
res.len() == crossbar_elem_width,
|
res.len() == crossbar_elem_width,
|
||||||
"mvm generate a vector bigger thant it's requested elements"
|
"mvm generate a vector bigger thant it's requested elements"
|
||||||
);
|
);
|
||||||
|
|
||||||
let res_up: Cow<[T]> = res.as_slice().up();
|
let res_up: Cow<[T]> = res.as_slice().up();
|
||||||
core.execute_store(rd_val, res_up.as_ref());
|
core.execute_store(rd_val, res_up.as_ref());
|
||||||
|
|
||||||
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
|
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
|
||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
|
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
|
||||||
@@ -307,10 +340,12 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -349,10 +384,12 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -394,6 +431,7 @@ pub fn vvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
|||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -430,10 +468,12 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -466,10 +506,12 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -503,22 +545,26 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!(
|
panic!(
|
||||||
"Shift left on floating point what does it means? who has generated this instruction???"
|
"Shift left on floating point what does it means? who has generated this instruction???"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!(
|
panic!(
|
||||||
"Shift right on floating point what does it means? who has generated this instruction???"
|
"Shift right on floating point what does it means? who has generated this instruction???"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -533,7 +579,10 @@ where
|
|||||||
let r2_val = r2;
|
let r2_val = r2;
|
||||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||||
let rd_val = core.register(rd);
|
let rd_val = core.register(rd);
|
||||||
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
ensure!(
|
||||||
|
offset_select == 1,
|
||||||
|
"Offset select cannot be different from 1"
|
||||||
|
);
|
||||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||||
let load1 = loads[0];
|
let load1 = loads[0];
|
||||||
@@ -545,10 +594,12 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -575,10 +626,12 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -603,10 +656,12 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
@@ -629,11 +684,16 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
panic!("You are calling a placeholder, the real call is the generic version");
|
panic!("You are calling a placeholder, the real call is the generic version");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
#[inline(never)]
|
||||||
|
pub(super) fn vsoftmax_impl<F, T>(
|
||||||
|
cores: &mut CPU,
|
||||||
|
data: InstructionData,
|
||||||
|
) -> Result<InstructionStatus>
|
||||||
where
|
where
|
||||||
[F]: UpcastSlice<T>,
|
[F]: UpcastSlice<T>,
|
||||||
T: UpcastDestTraits<T> + MemoryStorable,
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
@@ -656,12 +716,11 @@ where
|
|||||||
.reduce(|a, b| if a > b { a } else { b })
|
.reduce(|a, b| if a > b { a } else { b })
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
|
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
|
||||||
let sum = exp_values
|
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap();
|
||||||
.iter()
|
ensure!(
|
||||||
.copied()
|
sum > 0.0.into(),
|
||||||
.reduce(|a, b| a + b)
|
"vsoftmax normalization sum must be positive"
|
||||||
.unwrap();
|
);
|
||||||
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
|
|
||||||
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
|
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
|
||||||
let res_up: Cow<[T]> = res.as_slice().up();
|
let res_up: Cow<[T]> = res.as_slice().up();
|
||||||
core.execute_store(rd_val, res_up.as_ref());
|
core.execute_store(rd_val, res_up.as_ref());
|
||||||
@@ -669,14 +728,17 @@ where
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
@@ -684,6 +746,7 @@ pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
///Communication/synchronization Instructions/////////////////
|
///Communication/synchronization Instructions/////////////////
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
|
#[inline(never)]
|
||||||
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_ld(cores, data);
|
TRACER.lock().unwrap().pre_ld(cores, data);
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
@@ -700,6 +763,7 @@ pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_st(cores, data);
|
TRACER.lock().unwrap().pre_st(cores, data);
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
@@ -716,6 +780,7 @@ pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_lldi(cores, data);
|
TRACER.lock().unwrap().pre_lldi(cores, data);
|
||||||
let (core, rd, imm) = data.get_core_rd_imm();
|
let (core, rd, imm) = data.get_core_rd_imm();
|
||||||
@@ -732,6 +797,7 @@ pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_lmv(cores, data);
|
TRACER.lock().unwrap().pre_lmv(cores, data);
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
@@ -748,20 +814,32 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn isa_send(functor : usize) -> bool{
|
||||||
|
(send as *const () as usize) == functor
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_send(cores, data);
|
|
||||||
Ok(InstructionStatus::Sending(data))
|
Ok(InstructionStatus::Sending(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
|
pub fn isa_recv(functor : usize) -> bool{
|
||||||
|
(recv as *const () as usize) == functor
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
TRACER.lock().unwrap().pre_recv(cores, data);
|
|
||||||
Ok(InstructionStatus::Reciving(data))
|
Ok(InstructionStatus::Reciving(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
Ok(InstructionStatus::Waiting(data))
|
Ok(InstructionStatus::Waiting(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(never)]
|
||||||
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||||
Ok(InstructionStatus::Sync(data))
|
Ok(InstructionStatus::Sync(data))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ pub mod helper;
|
|||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
pub struct Instruction {
|
pub struct Instruction {
|
||||||
pub data: InstructionData,
|
pub data: InstructionData,
|
||||||
functor: InstructionType,
|
pub functor: InstructionType,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Default)]
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
|||||||
@@ -567,7 +567,7 @@ fn json_to_send(
|
|||||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
inst_data_builder
|
inst_data_builder
|
||||||
.set_rd(rd)
|
.set_rd(rd)
|
||||||
.set_imm_core(core)
|
.set_imm_core(core + 1)
|
||||||
.set_imm_len(size)
|
.set_imm_len(size)
|
||||||
.set_offset_select(offset_select)
|
.set_offset_select(offset_select)
|
||||||
.set_offset_value(offset_value);
|
.set_offset_value(offset_value);
|
||||||
@@ -588,7 +588,7 @@ fn json_to_recv(
|
|||||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
inst_data_builder
|
inst_data_builder
|
||||||
.set_rd(rd)
|
.set_rd(rd)
|
||||||
.set_imm_core(core)
|
.set_imm_core(core + 1)
|
||||||
.set_imm_len(size)
|
.set_imm_len(size)
|
||||||
.set_offset_select(offset_select)
|
.set_offset_select(offset_select)
|
||||||
.set_offset_value(offset_value);
|
.set_offset_value(offset_value);
|
||||||
|
|||||||
+15
-28
@@ -1,45 +1,32 @@
|
|||||||
use core::panic;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::{fs::File, io::BufReader};
|
||||||
|
|
||||||
use serde_json::{Map, Value};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
CoreInstructionsBuilder, Executable,
|
CoreInstructionsBuilder, Executable,
|
||||||
cpu::{CPU, crossbar::{self, Crossbar}},
|
cpu::{CPU, crossbar::Crossbar},
|
||||||
instruction_set::{
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||||
InstructionsBuilder,
|
json_to_instruction::json_isa,
|
||||||
instruction_data::{self, InstructionData, InstructionDataBuilder},
|
|
||||||
},
|
|
||||||
json_to_instruction::{self, json_isa},
|
|
||||||
memory_manager::type_traits::TryToUsize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub fn json_to_executor<'a, 'b>(
|
||||||
pub fn json_to_executor<'a>(
|
|
||||||
config: Value,
|
config: Value,
|
||||||
mut cores: impl Iterator<Item = &'a Value>,
|
cores: &'b mut Vec<BufReader<File>>,
|
||||||
crossbars : Vec<Vec<&'a Crossbar>>
|
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||||
) -> Executable<'a> {
|
) -> Executable<'a> {
|
||||||
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
|
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
|
|
||||||
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
|
|
||||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
|
||||||
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
|
|
||||||
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
|
|
||||||
|
|
||||||
let mut cpu = CPU::new(core_cnt, crossbars);
|
let cpu = CPU::new(core_cnt, crossbars);
|
||||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||||
cores.next();
|
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
|
||||||
for core_indx in 1..=core_cnt {
|
let core_indx = external_core_indx as i32 + 1;
|
||||||
let mut insts_builder = InstructionsBuilder::new();
|
let mut insts_builder = InstructionsBuilder::new();
|
||||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||||
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
||||||
let json_core = cores
|
let json_core: Value = serde_json::from_reader(json_core_reader)
|
||||||
.next()
|
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
|
||||||
.unwrap_or_else(|| panic!("cores files less than {}", core_indx ));
|
|
||||||
let json_core_insts = json_core
|
let json_core_insts = json_core
|
||||||
.as_array()
|
.as_array()
|
||||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", core_indx));
|
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
|
||||||
for json_inst in json_core_insts {
|
for json_inst in json_core_insts {
|
||||||
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
|
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
mod json_isa;
|
pub(crate) mod json_isa;
|
||||||
pub mod json_to_executor;
|
pub mod json_to_executor;
|
||||||
|
|||||||
@@ -55,17 +55,25 @@ pub trait HasSigm {
|
|||||||
|
|
||||||
impl HasSigm for f32 {
|
impl HasSigm for f32 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
|
if self >= 0.0 {
|
||||||
|
1.0 / (1.0 + (-self).exp())
|
||||||
|
} else {
|
||||||
let ex = self.exp();
|
let ex = self.exp();
|
||||||
ex / (1.0 + ex)
|
ex / (1.0 + ex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl HasSigm for f64 {
|
impl HasSigm for f64 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
|
if self >= 0.0 {
|
||||||
|
1.0 / (1.0 + (-self).exp())
|
||||||
|
} else {
|
||||||
let ex = self.exp();
|
let ex = self.exp();
|
||||||
ex / (1.0 + ex)
|
ex / (1.0 + ex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait HasExp {
|
pub trait HasExp {
|
||||||
fn exp(self) -> Self;
|
fn exp(self) -> Self;
|
||||||
|
|||||||
@@ -1,50 +1,62 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
use crate::{
|
use anyhow::{Result, bail};
|
||||||
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
|
use std::{
|
||||||
|
collections::{HashMap, HashSet},
|
||||||
|
time::{Duration, SystemTime},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
cpu::CPU,
|
||||||
|
instruction_set::{
|
||||||
|
Instruction, InstructionStatus, Instructions,
|
||||||
|
isa::{NAMES, functor_to_name, isa_recv, isa_send},
|
||||||
|
},
|
||||||
|
memory_manager::type_traits::TryToUsize,
|
||||||
|
send_recv::{SendRecv, handle_send_recv},
|
||||||
|
tracing::TRACER,
|
||||||
|
};
|
||||||
|
pub mod binary_to_instruction;
|
||||||
pub mod cpu;
|
pub mod cpu;
|
||||||
pub mod instruction_set;
|
pub mod instruction_set;
|
||||||
|
pub mod json_to_instruction;
|
||||||
pub mod memory_manager;
|
pub mod memory_manager;
|
||||||
pub mod send_recv;
|
pub mod send_recv;
|
||||||
pub mod utility;
|
|
||||||
pub mod json_to_instruction;
|
|
||||||
pub mod tracing;
|
pub mod tracing;
|
||||||
|
pub mod utility;
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CoreInstructionsBuilder {
|
pub struct CoreInstructionsBuilder {
|
||||||
core_instructions : Vec<CoreInstruction>
|
core_instructions: Vec<CoreInstructions>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CoreInstructionsBuilder {
|
impl CoreInstructionsBuilder {
|
||||||
pub fn new(size: usize) -> Self {
|
pub fn new(size: usize) -> Self {
|
||||||
let mut core_instructions = Vec::with_capacity(size);
|
let mut core_instructions = Vec::with_capacity(size);
|
||||||
for _ in 0..=size {
|
for _ in 0..=size {
|
||||||
core_instructions.push(CoreInstruction::empty());
|
core_instructions.push(CoreInstructions::empty());
|
||||||
}
|
}
|
||||||
Self { core_instructions }
|
Self { core_instructions }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(self) -> Vec<CoreInstruction> {
|
pub fn build(self) -> Vec<CoreInstructions> {
|
||||||
self.core_instructions
|
self.core_instructions
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_core(&mut self, core: impl TryToUsize, core_instruction: Instructions) -> &mut Self {
|
pub fn set_core(&mut self, core: impl TryToUsize, core_instruction: Instructions) -> &mut Self {
|
||||||
self.core_instructions[core.try_into().expect("Set core with not valid size")] = core_instruction.into();
|
self.core_instructions[core.try_into().expect("Set core with not valid size")] =
|
||||||
|
core_instruction.into();
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CoreInstruction {
|
pub struct CoreInstructions {
|
||||||
instructions: Instructions,
|
instructions: Instructions,
|
||||||
program_counter: usize,
|
program_counter: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CoreInstruction {
|
impl CoreInstructions {
|
||||||
fn new(instructions: Instructions, program_counter: usize) -> Self {
|
fn new(instructions: Instructions, program_counter: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
instructions,
|
instructions,
|
||||||
@@ -53,13 +65,16 @@ impl CoreInstruction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn empty() -> Self {
|
fn empty() -> Self {
|
||||||
Self { instructions: Vec::new(), program_counter: 0 }
|
Self {
|
||||||
|
instructions: Vec::new(),
|
||||||
|
program_counter: 0,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Instructions> for CoreInstruction {
|
impl From<Instructions> for CoreInstructions {
|
||||||
fn from(value: Instructions) -> Self {
|
fn from(value: Instructions) -> Self {
|
||||||
CoreInstruction {
|
CoreInstructions {
|
||||||
instructions: value,
|
instructions: value,
|
||||||
program_counter: 0,
|
program_counter: 0,
|
||||||
}
|
}
|
||||||
@@ -69,39 +84,67 @@ impl From<Instructions> for CoreInstruction {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Executable<'a> {
|
pub struct Executable<'a> {
|
||||||
cpu: CPU<'a>,
|
cpu: CPU<'a>,
|
||||||
core_instructions: Vec<CoreInstruction>,
|
core_instructions: Vec<CoreInstructions>,
|
||||||
send_recv: SendRecv,
|
send_recv: SendRecv,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DeadlockInfo {
|
||||||
|
cycle: String,
|
||||||
|
states: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||||
|
let mut tot_instructions = 0;
|
||||||
|
let mut progress = 0;
|
||||||
|
for core_instruction in core_instructions.iter() {
|
||||||
|
tot_instructions += core_instruction.instructions.len();
|
||||||
|
progress += core_instruction.program_counter;
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"Progress: {}% ({}/{}) ",
|
||||||
|
progress as f32 / tot_instructions as f32 * 100.0,
|
||||||
|
progress,
|
||||||
|
tot_instructions
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> Executable<'a> {
|
impl<'a> Executable<'a> {
|
||||||
pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstruction>) -> Executable<'a> {
|
pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstructions>) -> Executable<'a> {
|
||||||
let num_core = cpu.num_core();
|
let num_core = cpu.num_core();
|
||||||
let send_recv = SendRecv::new(num_core);
|
let send_recv = SendRecv::new(num_core);
|
||||||
assert_eq!(num_core, core_instructions.len(), "Some core doesn't have is list of istruction (required even if empty)");
|
assert_eq!(
|
||||||
|
num_core,
|
||||||
|
core_instructions.len(),
|
||||||
|
"Some core doesn't have is list of istruction (required even if empty)"
|
||||||
|
);
|
||||||
Self {
|
Self {
|
||||||
cpu,
|
cpu,
|
||||||
core_instructions,
|
core_instructions,
|
||||||
send_recv
|
send_recv,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute<'b>(&'b mut self)
|
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||||
where 'a : 'b
|
where
|
||||||
|
'a: 'b,
|
||||||
{
|
{
|
||||||
let Self {
|
let Self {
|
||||||
cpu,
|
cpu,
|
||||||
core_instructions,
|
core_instructions: cores_instructions,
|
||||||
send_recv
|
send_recv,
|
||||||
} = self;
|
} = self;
|
||||||
let mut cpu_progressed = 0;
|
let mut cpu_progressed = 0;
|
||||||
let max_core = cpu.num_core();
|
let max_core = cpu.num_core();
|
||||||
let mut index_unit = 0;
|
let mut cpu_index = 0;
|
||||||
|
let mut now = SystemTime::now();
|
||||||
|
|
||||||
while (cpu_progressed > -2) {
|
while (cpu_progressed > -2) {
|
||||||
let mut core_result = InstructionStatus::Completed;
|
let mut core_result = InstructionStatus::Completed;
|
||||||
while core_result.is_completed() && let Some(core_instruction) = core_instructions.get_mut(index_unit){
|
while core_result.is_completed()
|
||||||
|
&& let Some(core_instruction) = cores_instructions.get_mut(cpu_index)
|
||||||
|
{
|
||||||
core_result = InstructionStatus::NotExecuted;
|
core_result = InstructionStatus::NotExecuted;
|
||||||
let CoreInstruction {
|
let CoreInstructions {
|
||||||
instructions,
|
instructions,
|
||||||
program_counter,
|
program_counter,
|
||||||
} = core_instruction;
|
} = core_instruction;
|
||||||
@@ -114,16 +157,56 @@ impl<'a> Executable<'a> {
|
|||||||
cpu_progressed = 0;
|
cpu_progressed = 0;
|
||||||
*program_counter += 1;
|
*program_counter += 1;
|
||||||
}
|
}
|
||||||
|
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||||
|
print_status(cores_instructions);
|
||||||
|
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||||
|
bail!(
|
||||||
|
"Deadlock cycle detected: {} [{}]",
|
||||||
|
deadlock.cycle,
|
||||||
|
deadlock.states
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if handle_send_recv(cpu, core_instructions, send_recv, core_result) { cpu_progressed = 0; }
|
now = SystemTime::now();
|
||||||
handle_wait_sync(cpu, core_instructions, core_result);
|
}
|
||||||
index_unit = if index_unit + 1 >= max_core {
|
}
|
||||||
|
handle_wait_sync(cpu, cores_instructions, core_result);
|
||||||
|
match handle_send_recv(cpu, cores_instructions, send_recv, core_result) {
|
||||||
|
(true, other_cpu_index) => {
|
||||||
|
cpu_progressed = 0;
|
||||||
|
cpu_index = other_cpu_index;
|
||||||
|
}
|
||||||
|
(false, 0) => {
|
||||||
|
cpu_index = if cpu_index + 1 >= cores_instructions.len() {
|
||||||
cpu_progressed -= 1;
|
cpu_progressed -= 1;
|
||||||
0
|
0
|
||||||
} else {
|
} else {
|
||||||
index_unit + 1
|
cpu_index + 1
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
(false, other_cpu_index) => {
|
||||||
|
cpu_index = other_cpu_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
print_status(cores_instructions);
|
||||||
|
|
||||||
|
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||||
|
bail!(
|
||||||
|
"Deadlock cycle detected: {} [{}]",
|
||||||
|
deadlock.cycle,
|
||||||
|
deadlock.states
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if cores_instructions
|
||||||
|
.iter()
|
||||||
|
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||||
|
{
|
||||||
|
bail!("Execution stalled with unfinished instructions");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "profile_time")]
|
||||||
|
TRACER.lock().unwrap().report();
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cpu(&self) -> &CPU<'a> {
|
pub fn cpu(&self) -> &CPU<'a> {
|
||||||
@@ -145,13 +228,130 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_wait_sync<'a, 'b, 'c >(cpu: &'b mut CPU<'a>, core_instructions: &'c mut [CoreInstruction], core_result: InstructionStatus)
|
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||||
where 'a : 'b,
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
'a : 'c
|
enum CoreState {
|
||||||
{
|
SendingTo(i32, i32),
|
||||||
|
ReceivingFrom(i32, i32),
|
||||||
|
Working,
|
||||||
|
Halted,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut states = HashMap::new();
|
||||||
|
|
||||||
|
for core_inst in cores_instructions.iter() {
|
||||||
|
if core_inst.program_counter >= core_inst.instructions.len() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Instruction { data, functor } = core_inst.instructions[core_inst.program_counter];
|
||||||
|
let functor_address = functor as usize;
|
||||||
|
|
||||||
|
let (this_core, target_core) = data.get_core_immcore();
|
||||||
|
|
||||||
|
if isa_recv(functor_address) {
|
||||||
|
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||||
|
} else if isa_send(functor_address) {
|
||||||
|
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||||
|
} else {
|
||||||
|
states.insert(this_core, CoreState::Working);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut wait_for = HashMap::new();
|
||||||
|
|
||||||
|
for (&core_id, state) in states.iter() {
|
||||||
|
match state {
|
||||||
|
CoreState::SendingTo(target_core, size) => {
|
||||||
|
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||||
|
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||||
|
wait_for.insert(core_id, *target_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CoreState::ReceivingFrom(target_core, size) => {
|
||||||
|
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||||
|
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||||
|
wait_for.insert(core_id, *target_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CoreState::Working | CoreState::Halted => {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut visited = HashSet::new();
|
||||||
|
|
||||||
|
for &start_core in wait_for.keys() {
|
||||||
|
if visited.contains(&start_core) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut path = Vec::new();
|
||||||
|
let mut current_core = start_core;
|
||||||
|
let mut in_path = HashSet::new();
|
||||||
|
|
||||||
|
while let Some(&waiting_for) = wait_for.get(¤t_core) {
|
||||||
|
path.push(current_core);
|
||||||
|
in_path.insert(current_core);
|
||||||
|
visited.insert(current_core);
|
||||||
|
|
||||||
|
// Found a closed loop!
|
||||||
|
if in_path.contains(&waiting_for) {
|
||||||
|
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||||
|
let cycle = &path[cycle_start..];
|
||||||
|
|
||||||
|
let cycle_str = cycle
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" -> ");
|
||||||
|
|
||||||
|
let cycle = cycle
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.chain(std::iter::once(waiting_for))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||||
|
let states_msg = cycle
|
||||||
|
.iter()
|
||||||
|
.filter_map(|core| {
|
||||||
|
states.get(core).map(|state| match state {
|
||||||
|
CoreState::SendingTo(target, size) => {
|
||||||
|
format!("core {} send {}B -> {}", core, size, target)
|
||||||
|
}
|
||||||
|
CoreState::ReceivingFrom(source, size) => {
|
||||||
|
format!("core {} recv {}B <- {}", core, size, source)
|
||||||
|
}
|
||||||
|
CoreState::Working => format!("core {} working", core),
|
||||||
|
CoreState::Halted => format!("core {} halted", core),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
|
||||||
|
return Some(DeadlockInfo {
|
||||||
|
cycle: cycle_msg,
|
||||||
|
states: states_msg,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hit a known branch that didn't result in a cycle
|
||||||
|
if visited.contains(&waiting_for) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
current_core = waiting_for;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_wait_sync<'a, 'b, 'c>(
|
||||||
|
cpu: &'b mut CPU<'a>,
|
||||||
|
core_instructions: &'c mut [CoreInstructions],
|
||||||
|
core_result: InstructionStatus,
|
||||||
|
) where
|
||||||
|
'a: 'b,
|
||||||
|
'a: 'c,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
CoreInstruction, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
|
CoreInstructions, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
|
||||||
utility::add_offset_rd,
|
utility::add_offset_rd,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -43,14 +43,14 @@ impl SendRecv {
|
|||||||
|
|
||||||
pub fn handle_send_recv<'a, 'b >(
|
pub fn handle_send_recv<'a, 'b >(
|
||||||
cpu: &'b mut CPU<'a>,
|
cpu: &'b mut CPU<'a>,
|
||||||
core_instructions: & mut [CoreInstruction],
|
core_instructions: & mut [CoreInstructions],
|
||||||
send_recv: & mut SendRecv,
|
send_recv: & mut SendRecv,
|
||||||
core_result: InstructionStatus,
|
core_result: InstructionStatus,
|
||||||
) -> bool
|
) -> (bool, usize)
|
||||||
where 'a : 'b
|
where 'a : 'b
|
||||||
{
|
{
|
||||||
let transfer_memory = |cpu: &'b mut CPU<'a>,
|
let transfer_memory = |cpu: &'b mut CPU<'a>,
|
||||||
core_instructions: & mut [CoreInstruction],
|
core_instructions: & mut [CoreInstructions],
|
||||||
sender: Option<SendRecvInfo>,
|
sender: Option<SendRecvInfo>,
|
||||||
receiver: Option<SendRecvInfo>| {
|
receiver: Option<SendRecvInfo>| {
|
||||||
if let Some(sender) = sender
|
if let Some(sender) = sender
|
||||||
@@ -58,6 +58,20 @@ where 'a : 'b
|
|||||||
&& sender.internal_core == receiver.external_core
|
&& sender.internal_core == receiver.external_core
|
||||||
&& receiver.internal_core == sender.external_core
|
&& receiver.internal_core == sender.external_core
|
||||||
{
|
{
|
||||||
|
{
|
||||||
|
let sender = &mut core_instructions[sender.internal_core];
|
||||||
|
let pc = sender.program_counter;
|
||||||
|
let inst = sender.instructions.get(pc).unwrap();
|
||||||
|
let data = inst.data;
|
||||||
|
TRACER.lock().unwrap().pre_send(cpu, data);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let recv = &mut core_instructions[receiver.internal_core];
|
||||||
|
let pc = recv.program_counter;
|
||||||
|
let inst = recv.instructions.get(pc).unwrap();
|
||||||
|
let data = inst.data;
|
||||||
|
TRACER.lock().unwrap().pre_recv(cpu, data);
|
||||||
|
}
|
||||||
let [sender_core, reciver_core] =
|
let [sender_core, reciver_core] =
|
||||||
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
|
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
|
||||||
let memory = sender_core
|
let memory = sender_core
|
||||||
@@ -119,7 +133,7 @@ where 'a : 'b
|
|||||||
send_recv.sending[sender] = None;
|
send_recv.sending[sender] = None;
|
||||||
send_recv.receiving[receiver] = None;
|
send_recv.receiving[receiver] = None;
|
||||||
}
|
}
|
||||||
transfered
|
(transfered, receiver)
|
||||||
}
|
}
|
||||||
InstructionStatus::Reciving(instruction_data) => {
|
InstructionStatus::Reciving(instruction_data) => {
|
||||||
let (core_idx, imm_core) = instruction_data.get_core_immcore();
|
let (core_idx, imm_core) = instruction_data.get_core_immcore();
|
||||||
@@ -148,8 +162,8 @@ where 'a : 'b
|
|||||||
send_recv.sending[sender] = None;
|
send_recv.sending[sender] = None;
|
||||||
send_recv.receiving[receiver] = None;
|
send_recv.receiving[receiver] = None;
|
||||||
}
|
}
|
||||||
transfered
|
(transfered, sender)
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => (false, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
#[cfg(not(feature = "tracing"))]
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
impl Trace {
|
impl Trace {
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
/////////////////Scalar/register Instructions//////////////////
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
|
|||||||
@@ -1,53 +1,32 @@
|
|||||||
mod tracing_isa;
|
|
||||||
mod disable;
|
mod disable;
|
||||||
mod pretty_print;
|
#[cfg(feature = "profile_time")]
|
||||||
|
mod profile;
|
||||||
|
|
||||||
|
#[cfg(feature = "profile_time")]
|
||||||
|
use profile::Trace;
|
||||||
|
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
use std::{fs::File, path::{ PathBuf}};
|
mod trace;
|
||||||
use std::sync::{LazyLock, Mutex};
|
#[cfg(feature = "tracing")]
|
||||||
|
use trace::Trace;
|
||||||
|
|
||||||
use crate::Executable;
|
use crate::Executable;
|
||||||
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
#[cfg(feature = "tracing")]
|
use std::path::PathBuf;
|
||||||
pub struct Trace {
|
use std::sync::{LazyLock, Mutex};
|
||||||
out_files : Vec<File>
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
impl Trace {
|
pub struct Trace {}
|
||||||
fn new() -> Self {
|
|
||||||
Self { out_files : Vec::new()}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||||
pub fn init(&mut self, num_core : usize , mut path : PathBuf) {
|
|
||||||
path.pop();
|
|
||||||
for i in 0..num_core {
|
|
||||||
path.push(format!("TraceCore{}", i));
|
|
||||||
let file = File::create(&path).expect("Can not create file");
|
|
||||||
self.out_files.push(file);
|
|
||||||
path.pop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(feature = "tracing"))]
|
|
||||||
pub struct Trace {
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#[cfg(not(feature = "tracing"))]
|
|
||||||
impl Trace {
|
impl Trace {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {}
|
Self {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
|
||||||
pub fn init(&mut self, num_core : usize, path : PathBuf ) {
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| { Trace::new().into()});
|
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into());
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,73 @@
|
|||||||
|
use std::{collections::HashMap, path::PathBuf, time::Instant};
|
||||||
|
|
||||||
|
use crate::tracing::profile::profile_analysis::{
|
||||||
|
analyze_timings, generate_interactive_report, print_textual_report,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub mod profile_analysis;
|
||||||
|
pub mod profile_isa;
|
||||||
|
|
||||||
|
pub struct Trace {
|
||||||
|
instruction_times: HashMap<String, Vec<(u128,u128)>>,
|
||||||
|
core_start_time: HashMap<usize, Option<Instant>>,
|
||||||
|
start_time: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Trace {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let mut instruction_times = HashMap::new();
|
||||||
|
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
|
||||||
|
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
|
||||||
|
Self {
|
||||||
|
instruction_times,
|
||||||
|
core_start_time: HashMap::new(),
|
||||||
|
start_time: Instant::now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, num_core: usize, path: PathBuf) {
|
||||||
|
for i in 0..num_core {
|
||||||
|
self.core_start_time.insert(i, None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn report(&self) {
|
||||||
|
let res = analyze_timings(&self.instruction_times);
|
||||||
|
print_textual_report(&res);
|
||||||
|
generate_interactive_report(
|
||||||
|
&self.instruction_times,
|
||||||
|
&["mvmul", "recv"],
|
||||||
|
"/tmp/report.html",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,192 @@
|
|||||||
|
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
|
||||||
|
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct InstructionStats {
|
||||||
|
pub name: String,
|
||||||
|
pub count: usize,
|
||||||
|
pub total_time: u128,
|
||||||
|
pub min: f64,
|
||||||
|
pub max: f64,
|
||||||
|
pub mean: f64,
|
||||||
|
pub median: f64,
|
||||||
|
pub std_dev: f64,
|
||||||
|
pub cv: f64,
|
||||||
|
pub p95: f64,
|
||||||
|
pub p99: f64,
|
||||||
|
pub skewness: f64,
|
||||||
|
pub kurtosis: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_time(ns: f64) -> String {
|
||||||
|
if ns.is_nan() {
|
||||||
|
return "NaN".to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
if ns >= 1_000_000_000.0 {
|
||||||
|
format!("{:.2} s", ns / 1_000_000_000.0)
|
||||||
|
} else if ns >= 1_000_000.0 {
|
||||||
|
format!("{:.2} ms", ns / 1_000_000.0)
|
||||||
|
} else if ns >= 1_000.0 {
|
||||||
|
format!("{:.2} µs", ns / 1_000.0)
|
||||||
|
} else {
|
||||||
|
format!("{:.2} ns", ns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
|
||||||
|
let n = times.len() as f64;
|
||||||
|
|
||||||
|
if n < 4.0 || std_dev == 0.0 {
|
||||||
|
return (f64::NAN, f64::NAN);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sum_m3 = 0.0;
|
||||||
|
let mut sum_m4 = 0.0;
|
||||||
|
|
||||||
|
for &x in times {
|
||||||
|
let deviation = x - mean;
|
||||||
|
sum_m3 += deviation.powi(3);
|
||||||
|
sum_m4 += deviation.powi(4);
|
||||||
|
}
|
||||||
|
|
||||||
|
let m3 = sum_m3 / n;
|
||||||
|
let m4 = sum_m4 / n;
|
||||||
|
|
||||||
|
let skewness = m3 / std_dev.powi(3);
|
||||||
|
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
|
||||||
|
|
||||||
|
(skewness, kurtosis)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
for (instruction, times) in timings {
|
||||||
|
let count = times.len();
|
||||||
|
if count == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract ONLY the duration (the second element of the tuple) for stats
|
||||||
|
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
|
||||||
|
let total_time: u128 = durations.iter().sum();
|
||||||
|
|
||||||
|
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
|
||||||
|
let mut data = Data::new(f64_times.clone());
|
||||||
|
|
||||||
|
let mean = data.mean().unwrap_or(0.0);
|
||||||
|
let std_dev = data.std_dev().unwrap_or(0.0);
|
||||||
|
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
|
||||||
|
|
||||||
|
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
|
||||||
|
|
||||||
|
results.push(InstructionStats {
|
||||||
|
name: instruction.clone(),
|
||||||
|
count,
|
||||||
|
total_time,
|
||||||
|
min: data.min(),
|
||||||
|
max: data.max(),
|
||||||
|
mean,
|
||||||
|
median: data.median(),
|
||||||
|
std_dev,
|
||||||
|
cv,
|
||||||
|
p95: data.percentile(95),
|
||||||
|
p99: data.percentile(99),
|
||||||
|
skewness,
|
||||||
|
kurtosis,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn print_textual_report(stats: &[InstructionStats]) {
|
||||||
|
let mut table = Table::new();
|
||||||
|
table
|
||||||
|
.load_preset(UTF8_FULL)
|
||||||
|
.apply_modifier(UTF8_ROUND_CORNERS)
|
||||||
|
.set_header(vec![
|
||||||
|
"Instruction",
|
||||||
|
"Count",
|
||||||
|
"Total Time",
|
||||||
|
"Mean",
|
||||||
|
"Median",
|
||||||
|
"Min",
|
||||||
|
"Max",
|
||||||
|
"P95",
|
||||||
|
"P99",
|
||||||
|
"StdDev",
|
||||||
|
"CV",
|
||||||
|
"Skewness",
|
||||||
|
"Kurtosis",
|
||||||
|
]);
|
||||||
|
|
||||||
|
for stat in stats {
|
||||||
|
table.add_row(vec![
|
||||||
|
Cell::new(&stat.name),
|
||||||
|
Cell::new(stat.count.to_string()),
|
||||||
|
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
|
||||||
|
Cell::new(format_time(stat.mean)),
|
||||||
|
Cell::new(format_time(stat.median)),
|
||||||
|
Cell::new(format_time(stat.min)),
|
||||||
|
Cell::new(format_time(stat.max)),
|
||||||
|
Cell::new(format_time(stat.p95)),
|
||||||
|
Cell::new(format_time(stat.p99)),
|
||||||
|
Cell::new(format_time(stat.std_dev)),
|
||||||
|
Cell::new(format!("{:.3}", stat.cv)),
|
||||||
|
Cell::new(format!("{:.2}", stat.skewness)),
|
||||||
|
Cell::new(format!("{:.2}", stat.kurtosis)),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{table}");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub fn generate_interactive_report(
|
||||||
|
timings: &HashMap<String, Vec<(u128, u128)>>,
|
||||||
|
instructions_to_plot: &[&str], // <-- NEW: Only plot these
|
||||||
|
file_path: &str,
|
||||||
|
) {
|
||||||
|
|
||||||
|
use plotly::common::{Mode, Marker, Line};
|
||||||
|
use plotly::layout::{Axis, Layout};
|
||||||
|
use plotly::{Plot, Scatter};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
let mut plot = Plot::new();
|
||||||
|
|
||||||
|
for &instruction_name in instructions_to_plot {
|
||||||
|
// Only proceed if the instruction exists in our timings map
|
||||||
|
if let Some(times) = timings.get(instruction_name) {
|
||||||
|
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
|
||||||
|
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
|
||||||
|
|
||||||
|
let text_array: Vec<String> = times.iter()
|
||||||
|
.map(|&(_, dur)| format_time(dur as f64))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let trace = Scatter::new(x_axis, y_axis)
|
||||||
|
.name(instruction_name)
|
||||||
|
.mode(Mode::LinesMarkers)
|
||||||
|
.marker(Marker::new().size(4).opacity(0.6))
|
||||||
|
.line(Line::new().width(1.0))
|
||||||
|
.text_array(text_array)
|
||||||
|
.hover_info(plotly::common::HoverInfo::All);
|
||||||
|
|
||||||
|
plot.add_trace(trace);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let layout = Layout::new()
|
||||||
|
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
|
||||||
|
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
|
||||||
|
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
|
||||||
|
|
||||||
|
plot.set_layout(layout);
|
||||||
|
plot.write_html(file_path);
|
||||||
|
println!("🌐 Interactive timeline saved to {}", file_path);
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,364 @@
|
|||||||
|
use crate::{
|
||||||
|
cpu::CPU,
|
||||||
|
instruction_set::instruction_data::InstructionData,
|
||||||
|
memory_manager::{
|
||||||
|
MemoryStorable,
|
||||||
|
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
||||||
|
},
|
||||||
|
tracing::Trace,
|
||||||
|
utility::{add_offset_r1, add_offset_rd},
|
||||||
|
};
|
||||||
|
use std::io::Write;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
#[cfg(feature = "profile_time")]
|
||||||
|
impl Trace {
|
||||||
|
///////////////////////////////////////////////////////////////
|
||||||
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
|
///////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||||
|
let core_indx = core_indx as usize;
|
||||||
|
if self.core_start_time.get(&core_indx).unwrap().is_none() {
|
||||||
|
self.core_start_time.insert(core_indx, Some(Instant::now()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
|
||||||
|
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||||
|
let core_indx = core_indx as usize;
|
||||||
|
let Self {
|
||||||
|
instruction_times,
|
||||||
|
core_start_time,
|
||||||
|
start_time,
|
||||||
|
} = self;
|
||||||
|
let now = Instant::now();
|
||||||
|
instruction_times
|
||||||
|
.get_mut(name)
|
||||||
|
.unwrap()
|
||||||
|
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
|
||||||
|
self.core_start_time.insert(core_indx, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "sldi");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "sld");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "sadd");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "ssub");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "smul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "saddi");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "smuli");
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
///////////////////Matrix/vector Instructions////////////////////
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "setbw");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||||
|
[M]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||||
|
[M]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "mvmul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvadd");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvsub");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvmul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvdmul");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vvmax");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vavg");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vrelu");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vtanh");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vsigm");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||||
|
where
|
||||||
|
[F]: UpcastSlice<T>,
|
||||||
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
|
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||||
|
{
|
||||||
|
self.post_impl(cores, data, "vsoftmax");
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
/////Communication/synchronization Instructions/////////////////
|
||||||
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "ld");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "st");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "lldi");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "lmv");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "send");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.pre_impl(cores, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||||
|
self.post_impl(cores, data, "recv");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
use std::{fs::File, path::PathBuf};
|
||||||
|
|
||||||
|
pub mod pretty_print;
|
||||||
|
pub mod tracing_isa;
|
||||||
|
|
||||||
|
pub struct Trace {
|
||||||
|
out_files: Vec<File>,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl Trace {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
out_files: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
|
||||||
|
path.pop();
|
||||||
|
for i in 0..num_core {
|
||||||
|
path.push(format!("TraceCore{}", i));
|
||||||
|
let file = File::create(&path).expect("Can not create file");
|
||||||
|
self.out_files.push(file);
|
||||||
|
path.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
+1
-10
@@ -1,4 +1,4 @@
|
|||||||
use crate::tracing::pretty_print;
|
use crate::{tracing::trace::pretty_print, utility::add_offset_r2};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -13,7 +13,6 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
#[cfg(feature = "tracing")]
|
|
||||||
impl Trace {
|
impl Trace {
|
||||||
///////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////
|
||||||
/////////////////Scalar/register Instructions//////////////////
|
/////////////////Scalar/register Instructions//////////////////
|
||||||
@@ -284,7 +283,6 @@ impl Trace {
|
|||||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||||
F: UpcastDestTraits<F> + MemoryStorable,
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
{
|
{
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -358,8 +356,6 @@ impl Trace {
|
|||||||
T: UpcastDestTraits<T> + MemoryStorable,
|
T: UpcastDestTraits<T> + MemoryStorable,
|
||||||
F: UpcastDestTraits<F> + MemoryStorable,
|
F: UpcastDestTraits<F> + MemoryStorable,
|
||||||
{
|
{
|
||||||
use crate::{tracing::pretty_print, utility::add_offset_r2};
|
|
||||||
|
|
||||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -990,8 +986,6 @@ impl Trace {
|
|||||||
/////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -1044,8 +1038,6 @@ impl Trace {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
let file: &mut File = self
|
let file: &mut File = self
|
||||||
@@ -1138,7 +1130,6 @@ impl Trace {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||||
use crate::tracing::pretty_print;
|
|
||||||
|
|
||||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||||
data.get_core_rd_r1_r2_immlen_offset();
|
data.get_core_rd_r1_r2_immlen_offset();
|
||||||
@@ -1,6 +1,11 @@
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
use pimcore::{
|
||||||
|
Executable,
|
||||||
|
cpu::crossbar::Crossbar,
|
||||||
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
|
memory_manager::CoreMemory,
|
||||||
|
};
|
||||||
|
|
||||||
fn simple_read(path: &Path) -> Vec<f32> {
|
fn simple_read(path: &Path) -> Vec<f32> {
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
|
|||||||
fn mvmul_f32(err: &str)
|
fn mvmul_f32(err: &str)
|
||||||
where
|
where
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let matrix = simple_read(Path::new("tests/B.txt"));
|
||||||
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
|
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
|
||||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
crossbar.execute_store(&matrix).unwrap();
|
||||||
let matrix = simple_read(Path::new("B.txt")) ;
|
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||||
|
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||||
|
let vector = simple_read(Path::new("tests/A.txt"));
|
||||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
|
||||||
let vector = simple_read(Path::new("A.txt"));
|
|
||||||
memory.execute_store(0, &vector).unwrap();
|
memory.execute_store(0, &vector).unwrap();
|
||||||
|
|
||||||
let mut inst_builder = InstructionsBuilder::new();
|
let mut inst_builder = InstructionsBuilder::new();
|
||||||
@@ -57,7 +60,7 @@ where
|
|||||||
.cpu_mut()
|
.cpu_mut()
|
||||||
.host()
|
.host()
|
||||||
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
||||||
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||||
"Wrong result for {}",
|
"Wrong result for {}",
|
||||||
err
|
err
|
||||||
);
|
);
|
||||||
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
|
|||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
use pimcore::cpu::CPU;
|
||||||
|
|
||||||
|
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
|
||||||
|
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
|
||||||
|
}
|
||||||
@@ -1,51 +1,103 @@
|
|||||||
use std::{fs, io::BufReader, path::Path};
|
use std::{
|
||||||
|
fs::{self, File},
|
||||||
|
io::BufReader,
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use pimcore::json_to_instruction::json_to_executor;
|
use pimcore::{
|
||||||
|
cpu::crossbar::Crossbar,
|
||||||
|
json_to_instruction::json_to_executor,
|
||||||
|
memory_manager::CoreMemory,
|
||||||
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
|
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for entry in fs::read_dir(root)? {
|
for entry in fs::read_dir(root)? {
|
||||||
let entry = entry.context("Root not found")?;
|
let entry = entry.context("Root not found")?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
if path.is_dir() {
|
if path.is_dir() {
|
||||||
let mut cores = Vec::new();
|
result.push(path);
|
||||||
let mut config: Option<Value> = None;
|
|
||||||
for sub_entry in fs::read_dir(&path)
|
|
||||||
.with_context(|| format!("File {} not readable", path.display()))?
|
|
||||||
{
|
|
||||||
let sub_entry =
|
|
||||||
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
|
|
||||||
let sub_path = sub_entry.path();
|
|
||||||
if sub_path.is_file()
|
|
||||||
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
|
|
||||||
{
|
|
||||||
let file = fs::File::open(&sub_path)
|
|
||||||
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
|
|
||||||
"Serde reader fail for subpath {}",
|
|
||||||
sub_path.display()
|
|
||||||
))?;
|
|
||||||
if sub_path.file_name().unwrap() == "config.json" {
|
|
||||||
config = Some(val);
|
|
||||||
} else {
|
|
||||||
cores.push(val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result.push((config.unwrap(), cores));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn core_sort_key(path: &Path) -> i32 {
|
||||||
|
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||||
|
stem[5..].parse::<i32>().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn crossbar_sort_key(path: &Path) -> i32 {
|
||||||
|
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||||
|
stem[9..].parse::<i32>().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
|
||||||
|
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||||
|
let rows = xbar_size[0].as_i64().unwrap() as usize;
|
||||||
|
let cols = xbar_size[1].as_i64().unwrap() as usize;
|
||||||
|
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||||
|
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
|
||||||
|
owned_crossbars.push(Vec::new());
|
||||||
|
|
||||||
|
for core_idx in 0..core_cnt {
|
||||||
|
let core_folder = folder.join(format!("core_{core_idx}"));
|
||||||
|
let mut core_crossbars = Vec::new();
|
||||||
|
if core_folder.is_dir() {
|
||||||
|
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
|
||||||
|
.map(|entry| entry.map(|entry| entry.path()))
|
||||||
|
.collect::<std::io::Result<Vec<_>>>()?;
|
||||||
|
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
|
||||||
|
for path in paths {
|
||||||
|
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let bytes = fs::read(&path)
|
||||||
|
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
|
||||||
|
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
|
||||||
|
crossbar.execute_store(&bytes)?;
|
||||||
|
core_crossbars.push(crossbar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
owned_crossbars.push(core_crossbars);
|
||||||
|
}
|
||||||
|
Ok(owned_crossbars)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn json_folder_tester() {
|
fn json_folder_tester() {
|
||||||
let examples = collect_json_from_subfolders("data").unwrap();
|
let examples = collect_examples("tests/data").unwrap();
|
||||||
for example in examples {
|
for folder in examples {
|
||||||
let (config, cores) = example;
|
let config_path = folder.join("config.json");
|
||||||
json_to_executor::json_to_executor(config, cores.iter()).execute();
|
let config_file = File::open(&config_path).unwrap();
|
||||||
|
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
|
||||||
|
|
||||||
|
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||||
|
let mut core_paths: Vec<_> = fs::read_dir(&folder)
|
||||||
|
.unwrap()
|
||||||
|
.map(|entry| entry.unwrap().path())
|
||||||
|
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
|
||||||
|
.filter(|path| path.file_name().unwrap() != "config.json")
|
||||||
|
.collect();
|
||||||
|
core_paths.sort_by_cached_key(|path| core_sort_key(path));
|
||||||
|
assert_eq!(core_paths.len(), core_cnt);
|
||||||
|
|
||||||
|
let mut core_readers: Vec<_> = core_paths
|
||||||
|
.into_iter()
|
||||||
|
.map(|path| BufReader::new(File::open(path).unwrap()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
|
||||||
|
let crossbars = owned_crossbars
|
||||||
|
.iter()
|
||||||
|
.map(|core_crossbars| core_crossbars.iter().collect())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
|
||||||
|
let memory = fs::read(folder.join("memory.bin")).unwrap();
|
||||||
|
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
|
||||||
|
executable.execute();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,17 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
use pimcore::{
|
||||||
|
Executable,
|
||||||
|
instruction_set::{
|
||||||
|
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Function not found for the requested size") ]
|
#[should_panic(expected = "Function not found for the requested size") ]
|
||||||
fn wrong_size_place_holder() {
|
fn wrong_size_place_holder() {
|
||||||
let cpu = CPU::new(0);
|
let cpu = common::empty_cpu(0);
|
||||||
let mut inst_builder = InstructionsBuilder::new();
|
let mut inst_builder = InstructionsBuilder::new();
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(0).fix_core_indx();
|
idata_build.set_core_indx(0).fix_core_indx();
|
||||||
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
|
|||||||
|
|
||||||
|
|
||||||
fn place_holder(inst : InstructionType) {
|
fn place_holder(inst : InstructionType) {
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(0).fix_core_indx();
|
idata_build.set_core_indx(0).fix_core_indx();
|
||||||
inst(&mut cpu, idata_build.build()).unwrap();
|
inst(&mut cpu, idata_build.build()).unwrap();
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
use pimcore::{
|
use pimcore::{
|
||||||
Executable,
|
Executable,
|
||||||
cpu::CPU,
|
cpu::crossbar::Crossbar,
|
||||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
|
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// VVADD Test
|
/// VVADD Test
|
||||||
@@ -11,7 +13,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -115,7 +117,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -219,7 +221,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -323,7 +325,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -420,7 +422,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
9.0.into(),
|
9.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -524,7 +526,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
9.0.into(),
|
9.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -562,6 +564,7 @@ where
|
|||||||
vavg,
|
vavg,
|
||||||
idata_build
|
idata_build
|
||||||
.set_rdr1r2(3, 1, 1)
|
.set_rdr1r2(3, 1, 1)
|
||||||
|
.set_offset_select(1)
|
||||||
.set_imm_len(8 * size_of::<F>() as i32)
|
.set_imm_len(8 * size_of::<F>() as i32)
|
||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
@@ -617,7 +620,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
(-9.0).into(),
|
(-9.0).into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -717,7 +720,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
0.1.into(),
|
0.1.into(),
|
||||||
0.2.into(),
|
0.2.into(),
|
||||||
@@ -819,7 +822,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
0.1.into(),
|
0.1.into(),
|
||||||
0.2.into(),
|
0.2.into(),
|
||||||
@@ -923,9 +926,6 @@ where
|
|||||||
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
|
||||||
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
|
|
||||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
|
||||||
let matrix: [M; _] = [
|
let matrix: [M; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -944,7 +944,10 @@ where
|
|||||||
15.0.into(),
|
15.0.into(),
|
||||||
16.0.into(),
|
16.0.into(),
|
||||||
];
|
];
|
||||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
|
||||||
|
crossbar.execute_store(&matrix).unwrap();
|
||||||
|
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||||
|
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||||
let vector: [F; _] = [
|
let vector: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
|
|||||||
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
use pimcore::{
|
use pimcore::{
|
||||||
Executable, CoreInstructionsBuilder,
|
Executable, CoreInstructionsBuilder,
|
||||||
cpu::CPU,
|
|
||||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ld_test() {
|
fn ld_test() {
|
||||||
let mut cpu = CPU::new(1);
|
let mut cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -41,7 +42,7 @@ fn ld_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn st_test() {
|
fn st_test() {
|
||||||
let mut cpu = CPU::new(1);
|
let mut cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -76,7 +77,7 @@ fn st_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lldi_test() {
|
fn lldi_test() {
|
||||||
let cpu = CPU::new(1);
|
let cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let mut inst_builder = InstructionsBuilder::new();
|
let mut inst_builder = InstructionsBuilder::new();
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
@@ -106,7 +107,7 @@ fn lldi_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lmv_test() {
|
fn lmv_test() {
|
||||||
let mut cpu = CPU::new(1);
|
let mut cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -148,7 +149,7 @@ fn lmv_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn simple_send_recv_test() {
|
fn simple_send_recv_test() {
|
||||||
let mut cpu = CPU::new(2);
|
let mut cpu = common::empty_cpu(2);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn multiple_send_recv_test() {
|
fn multiple_send_recv_test() {
|
||||||
let mut cpu = CPU::new(4);
|
let mut cpu = common::empty_cpu(4);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 1.0, 1.0, 1.0, 1.0
|
1.0, 1.0, 1.0, 1.0, 1.0
|
||||||
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
|
|||||||
];
|
];
|
||||||
cpu.core(4).execute_store(0, &buff).unwrap();
|
cpu.core(4).execute_store(0, &buff).unwrap();
|
||||||
|
|
||||||
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
|
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(from).fix_core_indx();
|
idata_build.set_core_indx(from).fix_core_indx();
|
||||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||||
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
|
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(to).fix_core_indx();
|
idata_build.set_core_indx(to).fix_core_indx();
|
||||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||||
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
|
|||||||
|
|
||||||
|
|
||||||
// 1 -> 3
|
// 1 -> 3
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
|
send_inst(&mut inst_builder, 1, 3);
|
||||||
core_instruction_builder.set_core(1, inst_builder.build());
|
core_instruction_builder.set_core(1, inst_builder.build());
|
||||||
|
|
||||||
// 2 -> 3
|
// 2 -> 3
|
||||||
// 2 <- 4
|
// 2 <- 4
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
|
send_inst(&mut inst_builder, 2, 3);
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
|
recv_inst(&mut inst_builder, 2, 4);
|
||||||
core_instruction_builder.set_core(2, inst_builder.build());
|
core_instruction_builder.set_core(2, inst_builder.build());
|
||||||
|
|
||||||
// 3 <- 2
|
// 3 <- 2
|
||||||
// 3 <- 4
|
// 3 <- 4
|
||||||
// 3 <- 1
|
// 3 <- 1
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
|
recv_inst(&mut inst_builder, 3, 2);
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
|
recv_inst(&mut inst_builder, 3, 4);
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
|
recv_inst(&mut inst_builder, 3, 1);
|
||||||
core_instruction_builder.set_core(3, inst_builder.build());
|
core_instruction_builder.set_core(3, inst_builder.build());
|
||||||
// 4 -> 2
|
// 4 -> 2
|
||||||
// 4 -> 3
|
// 4 -> 3
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
|
send_inst(&mut inst_builder, 4, 2);
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
|
send_inst(&mut inst_builder, 4, 3);
|
||||||
core_instruction_builder.set_core(4, inst_builder.build());
|
core_instruction_builder.set_core(4, inst_builder.build());
|
||||||
|
|
||||||
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
||||||
|
|||||||
Submodule backend-simulators/pim/pimsim-nn updated: 3e3442b663...6d3b898e6b
@@ -68,5 +68,6 @@ add_pim_library(OMPIMAccel
|
|||||||
OMSpatialToPim
|
OMSpatialToPim
|
||||||
OMPimCommon
|
OMPimCommon
|
||||||
OMPimBufferization
|
OMPimBufferization
|
||||||
|
OMPimStaticMemoryCoalescing
|
||||||
MLIRTensorInferTypeOpInterfaceImpl
|
MLIRTensorInferTypeOpInterfaceImpl
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,15 @@
|
|||||||
add_pim_library(OMPimCommon
|
add_pim_library(OMPimCommon
|
||||||
PimCommon.cpp
|
IR/AddressAnalysis.cpp
|
||||||
|
IR/ConstantUtils.cpp
|
||||||
|
IR/CoreBlockUtils.cpp
|
||||||
|
IR/EntryPointUtils.cpp
|
||||||
|
IR/ShapeUtils.cpp
|
||||||
|
IR/SubviewUtils.cpp
|
||||||
|
IR/WeightUtils.cpp
|
||||||
|
Support/DebugDump.cpp
|
||||||
|
Support/Diagnostics.cpp
|
||||||
|
Support/FileSystemUtils.cpp
|
||||||
|
Support/ReportUtils.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,313 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (!moduleOp || !getGlobalOp)
|
||||||
|
return {};
|
||||||
|
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
if (!knowledge)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto iter = knowledge->aliases.find(value);
|
||||||
|
while (iter != knowledge->aliases.end()) {
|
||||||
|
value = iter->second;
|
||||||
|
iter = knowledge->aliases.find(value);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
if (mlir::isa<mlir::BlockArgument>(value))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
mlir::Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
|
||||||
|
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||||
|
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||||
|
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||||
|
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||||
|
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||||
|
|
||||||
|
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||||
|
const StaticValueKnowledge* knowledge) {
|
||||||
|
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||||
|
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto elementType = denseAttr.getElementType();
|
||||||
|
if (!elementType.isIndex() && !elementType.isInteger())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> indices;
|
||||||
|
indices.reserve(loadOp.getIndices().size());
|
||||||
|
for (mlir::Value index : loadOp.getIndices()) {
|
||||||
|
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||||
|
if (failed(resolvedIndex))
|
||||||
|
return mlir::failure();
|
||||||
|
indices.push_back(*resolvedIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||||
|
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||||
|
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
if (knowledge) {
|
||||||
|
auto iter = knowledge->indexValues.find(value);
|
||||||
|
if (iter != knowledge->indexValues.end())
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
|
||||||
|
if (constantOp) {
|
||||||
|
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
|
||||||
|
return integerAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
||||||
|
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
||||||
|
|
||||||
|
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return *lhs + *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return *lhs - *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return *lhs * *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||||
|
return mlir::failure();
|
||||||
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||||
|
return mlir::failure();
|
||||||
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||||
|
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||||
|
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
||||||
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
|
||||||
|
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
||||||
|
if (!integerAttr)
|
||||||
|
return mlir::failure();
|
||||||
|
return integerAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
|
||||||
|
const StaticValueKnowledge* knowledge) {
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (mlir::isa<mlir::BlockArgument>(value))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
mlir::Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
||||||
|
if (!tiedOperand)
|
||||||
|
return mlir::failure();
|
||||||
|
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
|
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||||
|
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||||
|
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||||
|
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value = yieldedValue;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||||
|
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||||
|
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> offsets;
|
||||||
|
llvm::SmallVector<int64_t> sizes;
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||||
|
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||||
|
strides.reserve(subviewOp.getMixedStrides().size());
|
||||||
|
|
||||||
|
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||||
|
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||||
|
if (failed(resolvedOffset))
|
||||||
|
return mlir::failure();
|
||||||
|
offsets.push_back(*resolvedOffset);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
|
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||||
|
if (failed(resolvedSize))
|
||||||
|
return mlir::failure();
|
||||||
|
sizes.push_back(*resolvedSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||||
|
if (failed(resolvedStride))
|
||||||
|
return mlir::failure();
|
||||||
|
strides.push_back(*resolvedStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||||
|
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
||||||
|
value = resolveAlias(castOp.getSource(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveIndexValueImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
|
||||||
|
return resolveContiguousAddressImpl(value, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||||
|
const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveContiguousAddressImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Describes a value as a base addressable object plus a statically known
|
||||||
|
/// byte offset after peeling aliases, casts, and contiguous subviews.
|
||||||
|
struct ResolvedContiguousAddress {
|
||||||
|
mlir::Value base;
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Records compile-time facts used when interpreting address arithmetic and
|
||||||
|
/// loop-carried aliases inside PIM regions.
|
||||||
|
struct StaticValueKnowledge {
|
||||||
|
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||||
|
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||||
|
|
||||||
|
StaticValueKnowledge() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|
||||||
|
/// Resolves a value to contiguous backing storage when that storage can be
|
||||||
|
/// proven statically from aliases, DPS ties, casts, and subviews.
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||||
|
const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
/// Statically evaluates index-like SSA values, including simple integer
|
||||||
|
/// arithmetic and loop facts recorded in `knowledge`.
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||||
|
/// loop-carried memref/result.
|
||||||
|
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,745 @@
|
|||||||
|
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
||||||
|
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
||||||
|
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace compact_asm {
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
enum class ListDelimiter {
|
||||||
|
Square,
|
||||||
|
Paren
|
||||||
|
};
|
||||||
|
|
||||||
|
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||||
|
if (delimiter == ListDelimiter::Square)
|
||||||
|
return parser.parseLSquare();
|
||||||
|
return parser.parseLParen();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||||
|
if (delimiter == ListDelimiter::Square)
|
||||||
|
return parser.parseOptionalRSquare();
|
||||||
|
return parser.parseOptionalRParen();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT>
|
||||||
|
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
||||||
|
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT>
|
||||||
|
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
||||||
|
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename EntryT, typename ParseEntryFn>
|
||||||
|
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<EntryT>& entries,
|
||||||
|
ParseEntryFn parseEntry) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
EntryT entry;
|
||||||
|
if (parseEntry(entry))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
entries.push_back(entry);
|
||||||
|
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline ParseResult
|
||||||
|
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<IntT> subgroup;
|
||||||
|
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(values, subgroup);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
int64_t first = 0;
|
||||||
|
if (parser.parseInteger(first))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
int64_t last = 0;
|
||||||
|
if (parser.parseInteger(last) || last < first)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||||
|
|
||||||
|
int64_t step = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||||
|
if (parser.parseInteger(step) || step <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||||
|
}
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
if ((last - first) % step != 0) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"range end must be reachable from start using the given step");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t value = first; value <= last; value += step)
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
values.push_back(static_cast<IntT>(value));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
values.push_back(static_cast<IntT>(first));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline ParseResult
|
||||||
|
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RangeT, typename PrintEntryFn>
|
||||||
|
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||||
|
for (size_t index = 0; index < entries.size();) {
|
||||||
|
size_t runEnd = index + 1;
|
||||||
|
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
||||||
|
++runEnd;
|
||||||
|
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printEntry(entries[index]);
|
||||||
|
size_t runLength = runEnd - index;
|
||||||
|
if (runLength > 1)
|
||||||
|
printer << " x" << runLength;
|
||||||
|
index = runEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT, typename IntT>
|
||||||
|
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
|
||||||
|
struct FlatCompression {
|
||||||
|
enum class Kind {
|
||||||
|
Single,
|
||||||
|
EqualRun,
|
||||||
|
Progression
|
||||||
|
};
|
||||||
|
|
||||||
|
Kind kind = Kind::Single;
|
||||||
|
size_t covered = 1;
|
||||||
|
size_t repeatCount = 1;
|
||||||
|
size_t progressionValueCount = 1;
|
||||||
|
int64_t step = 1;
|
||||||
|
IntT firstValue {};
|
||||||
|
IntT lastValue {};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto computeFlatCompression = [&](size_t start) {
|
||||||
|
FlatCompression compression;
|
||||||
|
compression.firstValue = values[start];
|
||||||
|
compression.lastValue = values[start];
|
||||||
|
|
||||||
|
auto findEqualRunEnd = [&](size_t runStart) {
|
||||||
|
size_t runEnd = runStart + 1;
|
||||||
|
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||||
|
++runEnd;
|
||||||
|
return runEnd;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t firstRunEnd = findEqualRunEnd(start);
|
||||||
|
compression.repeatCount = firstRunEnd - start;
|
||||||
|
size_t progressionEnd = firstRunEnd;
|
||||||
|
int64_t step = 0;
|
||||||
|
IntT lastValue = values[start];
|
||||||
|
|
||||||
|
if (firstRunEnd < values.size()) {
|
||||||
|
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||||
|
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||||
|
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||||
|
progressionEnd = secondRunEnd;
|
||||||
|
lastValue = values[firstRunEnd];
|
||||||
|
size_t currentRunStart = secondRunEnd;
|
||||||
|
while (currentRunStart < values.size()) {
|
||||||
|
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||||
|
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||||
|
break;
|
||||||
|
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||||
|
break;
|
||||||
|
lastValue = values[currentRunStart];
|
||||||
|
progressionEnd = currentRunEnd;
|
||||||
|
currentRunStart = currentRunEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
step = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compression.covered = 1;
|
||||||
|
if (progressionEnd > firstRunEnd) {
|
||||||
|
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||||
|
if (progressionValueCount >= 3) {
|
||||||
|
compression.kind = FlatCompression::Kind::Progression;
|
||||||
|
compression.covered = progressionEnd - start;
|
||||||
|
compression.progressionValueCount = progressionValueCount;
|
||||||
|
compression.step = step;
|
||||||
|
compression.lastValue = lastValue;
|
||||||
|
return compression;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (compression.repeatCount > 1) {
|
||||||
|
compression.kind = FlatCompression::Kind::EqualRun;
|
||||||
|
compression.covered = compression.repeatCount;
|
||||||
|
return compression;
|
||||||
|
}
|
||||||
|
|
||||||
|
return compression;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto findRepeatedSublist = [&](size_t start) {
|
||||||
|
size_t bestLength = 0;
|
||||||
|
size_t bestRepeatCount = 1;
|
||||||
|
size_t remaining = values.size() - start;
|
||||||
|
|
||||||
|
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||||
|
size_t repeatCount = 1;
|
||||||
|
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||||
|
while (start + (repeatCount + 1) * length <= values.size()
|
||||||
|
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||||
|
++repeatCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (repeatCount <= 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t covered = length * repeatCount;
|
||||||
|
size_t bestCovered = bestLength * bestRepeatCount;
|
||||||
|
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||||
|
bestLength = length;
|
||||||
|
bestRepeatCount = repeatCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::pair(bestLength, bestRepeatCount);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t index = 0; index < values.size();) {
|
||||||
|
if (index != 0)
|
||||||
|
stream << ", ";
|
||||||
|
|
||||||
|
FlatCompression flat = computeFlatCompression(index);
|
||||||
|
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||||
|
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||||
|
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||||
|
printOpenDelimiter(stream, ListDelimiter::Paren);
|
||||||
|
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
|
||||||
|
printCloseDelimiter(stream, ListDelimiter::Paren);
|
||||||
|
stream << " x" << sublistRepeatCount;
|
||||||
|
index += repeatedSublistCoverage;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (flat.kind) {
|
||||||
|
case FlatCompression::Kind::Progression:
|
||||||
|
stream << flat.firstValue << " to " << flat.lastValue;
|
||||||
|
if (flat.step != 1)
|
||||||
|
stream << " by " << flat.step;
|
||||||
|
if (flat.repeatCount > 1)
|
||||||
|
stream << " x" << flat.repeatCount;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
case FlatCompression::Kind::EqualRun:
|
||||||
|
stream << flat.firstValue << " x" << flat.repeatCount;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
case FlatCompression::Kind::Single:
|
||||||
|
stream << flat.firstValue;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename StreamT, typename IntT>
|
||||||
|
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(stream, delimiter);
|
||||||
|
printCompressedIntegerEntries(stream, values);
|
||||||
|
printCloseDelimiter(stream, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||||
|
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||||
|
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
|
||||||
|
for (size_t index = 0; index < values.size();) {
|
||||||
|
size_t equalRunEnd = index + 1;
|
||||||
|
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||||
|
++equalRunEnd;
|
||||||
|
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
if (equalRunEnd - index > 1) {
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
printer << " x" << (equalRunEnd - index);
|
||||||
|
index = equalRunEnd;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t rangeEnd = index + 1;
|
||||||
|
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||||
|
while (rangeEnd < values.size()) {
|
||||||
|
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||||
|
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||||
|
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++rangeEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||||
|
while (rangeEnd < values.size()) {
|
||||||
|
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||||
|
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||||
|
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++rangeEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
if (rangeEnd - index >= 3) {
|
||||||
|
printer << " to ";
|
||||||
|
printer.printOperand(values[rangeEnd - 1]);
|
||||||
|
}
|
||||||
|
else if (rangeEnd - index == 2) {
|
||||||
|
printer << ", ";
|
||||||
|
printer.printOperand(values[index + 1]);
|
||||||
|
}
|
||||||
|
index = rangeEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
|
||||||
|
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printCompressedValueSequence(printer, values);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printCompressedTypeSequence(printer, types);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||||
|
Type firstType;
|
||||||
|
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||||
|
if (!firstTypeResult.has_value()) {
|
||||||
|
if (allowEmpty)
|
||||||
|
return success();
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "expected type");
|
||||||
|
}
|
||||||
|
if (failed(*firstTypeResult))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto appendType = [&](Type type) -> ParseResult {
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
types.push_back(type);
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (appendType(firstType))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
Type nextType;
|
||||||
|
if (parser.parseType(nextType) || appendType(nextType))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||||
|
OpAsmParser::UnresolvedOperand firstOperand,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
OpAsmParser::UnresolvedOperand lastOperand;
|
||||||
|
if (parser.parseOperand(lastOperand))
|
||||||
|
return failure();
|
||||||
|
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
|
||||||
|
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
|
||||||
|
operands.push_back({firstOperand.location, firstOperand.name, number});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
|
operands.push_back(firstOperand);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
OpAsmParser::UnresolvedOperand firstOperand;
|
||||||
|
if (parser.parseOperand(firstOperand))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
|
||||||
|
return failure();
|
||||||
|
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||||
|
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SmallVector<Value> valueVec(values.begin(), values.end());
|
||||||
|
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
|
||||||
|
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
|
||||||
|
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||||
|
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SmallVector<Type> typeVec(types.begin(), types.end());
|
||||||
|
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
|
||||||
|
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
|
||||||
|
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
for (size_t index = 0; index < tupleSize; ++index) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
printer << " x" << (values.size() / tupleSize);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
for (size_t index = 0; index < tupleSize; ++index) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printType(types[index]);
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
printer << " x" << (types.size() / tupleSize);
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
|
||||||
|
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(operands, tupleOperands);
|
||||||
|
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
tupleOperands.clear();
|
||||||
|
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(operands, tupleOperands);
|
||||||
|
}
|
||||||
|
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult
|
||||||
|
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<Type> tupleTypes;
|
||||||
|
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(types, tupleTypes);
|
||||||
|
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
tupleTypes.clear();
|
||||||
|
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(types, tupleTypes);
|
||||||
|
}
|
||||||
|
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
Type type;
|
||||||
|
if (parser.parseType(type))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
types.push_back(type);
|
||||||
|
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||||
|
if (block.getNumArguments() == 0) {
|
||||||
|
printer << "() = ()";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (block.getNumArguments() == 1) {
|
||||||
|
printer.printOperand(block.getArgument(0));
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||||
|
OpAsmParser::Argument firstArgument,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
OpAsmParser::Argument lastArgument;
|
||||||
|
if (parser.parseArgument(lastArgument))
|
||||||
|
return failure();
|
||||||
|
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||||
|
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||||
|
}
|
||||||
|
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||||
|
arguments.push_back(argument);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
arguments.push_back(firstArgument);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
OpAsmParser::Argument firstArgument;
|
||||||
|
if (parser.parseArgument(firstArgument))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||||
|
argument.type = inputType;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
if (succeeded(parser.parseOptionalRParen())) {
|
||||||
|
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAsmParser::Argument firstArgument;
|
||||||
|
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseRParen() || parser.parseEqual()
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
arguments.push_back(argument);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace compact_asm
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
|
#include "ConstantUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Block* getHostConstantBlock(Operation* anchorOp) {
|
||||||
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
|
||||||
|
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||||
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||||
|
return current->getBlock();
|
||||||
|
|
||||||
|
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||||
|
return &funcOp.getBody().front();
|
||||||
|
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||||
|
return moduleOp.getBody();
|
||||||
|
return anchorOp->getBlock();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
|
||||||
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||||
|
for (Operation& op : *hostBlock) {
|
||||||
|
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||||
|
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||||
|
continue;
|
||||||
|
return constantOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||||
|
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||||
|
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||||
|
mlir::Attribute value,
|
||||||
|
mlir::Type type,
|
||||||
|
mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||||
|
return mlir::isa<mlir::arith::ConstantOp,
|
||||||
|
mlir::arith::AddIOp,
|
||||||
|
mlir::arith::SubIOp,
|
||||||
|
mlir::arith::MulIOp,
|
||||||
|
mlir::arith::DivUIOp,
|
||||||
|
mlir::arith::MinUIOp,
|
||||||
|
mlir::arith::RemUIOp,
|
||||||
|
mlir::arith::IndexCastOp,
|
||||||
|
mlir::memref::AllocOp,
|
||||||
|
mlir::memref::SubViewOp,
|
||||||
|
mlir::memref::CastOp,
|
||||||
|
mlir::memref::CollapseShapeOp,
|
||||||
|
mlir::memref::ExpandShapeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
walkPimCoreBlock(mlir::Block& block,
|
||||||
|
const StaticValueKnowledge& knowledge,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
for (mlir::Operation& op : block) {
|
||||||
|
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||||
|
continue;
|
||||||
|
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||||
|
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||||
|
mlir::Block& loopBody = forOp.getRegion().front();
|
||||||
|
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||||
|
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||||
|
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||||
|
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
||||||
|
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
||||||
|
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
||||||
|
StaticValueKnowledge loopKnowledge = knowledge;
|
||||||
|
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||||
|
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
||||||
|
loopKnowledge.aliases[iterArg] = iterValue;
|
||||||
|
|
||||||
|
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
||||||
|
hasFailure = true;
|
||||||
|
|
||||||
|
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
|
||||||
|
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
||||||
|
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (failed(callback(op, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
return mlir::success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns true for ops in a `pim.core` body that only participate in static
|
||||||
|
/// address or index computation and therefore do not emit PIM instructions.
|
||||||
|
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||||
|
|
||||||
|
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
|
||||||
|
/// their bounds are known and invoking `callback` only on instruction-emitting
|
||||||
|
/// operations.
|
||||||
|
mlir::LogicalResult
|
||||||
|
walkPimCoreBlock(mlir::Block& block,
|
||||||
|
const StaticValueKnowledge& knowledge,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
|
||||||
|
if (!moduleOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
|
||||||
|
if (entryPoints.size() > 1) {
|
||||||
|
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
if (!entryPoints.empty()) {
|
||||||
|
auto entryPointAttr =
|
||||||
|
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||||
|
if (!entryPointAttr) {
|
||||||
|
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||||
|
if (!entryFunc) {
|
||||||
|
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||||
|
<< entryPointAttr.getLeafReference().getValue();
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
return entryFunc;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
|
||||||
|
return mainGraphFunc;
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
|
||||||
|
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
|
||||||
|
if (!funcOp.isExternal())
|
||||||
|
nonExternalFuncs.push_back(funcOp);
|
||||||
|
if (nonExternalFuncs.size() == 1)
|
||||||
|
return nonExternalFuncs.front();
|
||||||
|
|
||||||
|
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Resolves the function the PIM pipeline should treat as its entry point.
|
||||||
|
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
|
||||||
|
/// non-external function if the module is otherwise unambiguous.
|
||||||
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
|
||||||
|
llvm::SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
|
||||||
|
llvm::SmallVector<int64_t> indices(shape.size(), 0);
|
||||||
|
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||||
|
indices[dim] = linearIndex / stride;
|
||||||
|
linearIndex %= stride;
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
|
||||||
|
int64_t linearIndex = 0;
|
||||||
|
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||||
|
linearIndex += index * stride;
|
||||||
|
return linearIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t dim : shape)
|
||||||
|
numElements *= dim;
|
||||||
|
return numElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides) {
|
||||||
|
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||||
|
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstNonZeroOffset = std::find_if(
|
||||||
|
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return offset != 0;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||||
|
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||||
|
if (size > dimension - offset)
|
||||||
|
return false;
|
||||||
|
++firstNonZeroOffset;
|
||||||
|
|
||||||
|
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, dimension] = sizeAndShape;
|
||||||
|
return size != dimension;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstDifferentSize != sizesAndShape.end()) {
|
||||||
|
++firstDifferentSize;
|
||||||
|
|
||||||
|
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, _dimension] = sizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value stripMemRefCasts(Value value) {
|
||||||
|
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||||
|
value = castOp.getSource();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value stripMemRefViewOps(Value value) {
|
||||||
|
while (true) {
|
||||||
|
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||||
|
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||||
|
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||||
|
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||||
|
value = stripMemRefViewOps(value);
|
||||||
|
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||||
|
if (!subviewOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||||
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
StaticSubviewInfo info;
|
||||||
|
info.source = source;
|
||||||
|
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||||
|
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||||
|
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||||
|
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
|
auto staticSize = getConstantIntValue(size);
|
||||||
|
if (!staticSize)
|
||||||
|
return failure();
|
||||||
|
info.sizes.push_back(*staticSize);
|
||||||
|
}
|
||||||
|
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto staticStride = getConstantIntValue(stride);
|
||||||
|
if (!staticStride)
|
||||||
|
return failure();
|
||||||
|
info.strides.push_back(*staticStride);
|
||||||
|
}
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||||
|
SmallVector<int64_t> staticOffsets;
|
||||||
|
staticOffsets.reserve(info.offsets.size());
|
||||||
|
for (OpFoldResult offset : info.offsets) {
|
||||||
|
auto staticOffset = getConstantIntValue(offset);
|
||||||
|
if (!staticOffset)
|
||||||
|
return failure();
|
||||||
|
staticOffsets.push_back(*staticOffset);
|
||||||
|
}
|
||||||
|
return staticOffsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct StaticSubviewInfo {
|
||||||
|
mlir::Value source;
|
||||||
|
llvm::SmallVector<int64_t> sourceShape;
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||||
|
llvm::SmallVector<int64_t> sizes;
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||||
|
|
||||||
|
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||||
|
|
||||||
|
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||||
|
|
||||||
|
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||||
|
|
||||||
|
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||||
|
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||||
|
|
||||||
|
void markWeightAlways(mlir::Operation* op) {
|
||||||
|
assert(op && "expected valid op");
|
||||||
|
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
|
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
|
||||||
|
bool found = false;
|
||||||
|
parentOp.walk([&](mlir::Operation* op) {
|
||||||
|
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||||
|
found |= mvmOp.getWeight() == weightArg;
|
||||||
|
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||||
|
found |= vmmOp.getWeight() == weightArg;
|
||||||
|
});
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||||
|
auto weights = parentOp.getWeights();
|
||||||
|
llvm::SmallSet<unsigned, 8> visited;
|
||||||
|
auto walkWeight = [&](mlir::Value weight) {
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||||
|
if (parentOp.getWeightArgument(weightIndex) != weight)
|
||||||
|
continue;
|
||||||
|
if (visited.insert(weightIndex).second)
|
||||||
|
callback(parentOp->getOpOperand(weightIndex));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||||
|
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
||||||
|
mlir::Operation* user = use.getOwner();
|
||||||
|
unsigned operandIndex = use.getOperandNumber();
|
||||||
|
|
||||||
|
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
|
||||||
|
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||||
|
llvm::SmallPtrSet<mlir::Value, 8> visited;
|
||||||
|
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
|
||||||
|
if (!visited.insert(currentValue).second)
|
||||||
|
return true;
|
||||||
|
if (currentValue.use_empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
|
||||||
|
if (isSpatialMvmVmmWeightUse(use))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
mlir::Operation* user = use.getOwner();
|
||||||
|
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
|
||||||
|
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||||
|
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
|
||||||
|
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||||
|
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
||||||
|
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||||
|
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
|
||||||
|
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return walkUses(value, walkUses);
|
||||||
|
}
|
||||||
|
|
||||||
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||||
|
assert(root && "expected valid root op");
|
||||||
|
root->walk([&](pim::PimCoreOp coreOp) {
|
||||||
|
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||||
|
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||||
|
callback(coreOp->getOpOperand(weightIndex));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||||
|
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||||
|
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool hasWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||||
|
/// weight across later PIM lowering/codegen stages.
|
||||||
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||||
|
|
||||||
|
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
|
||||||
|
/// allowing later passes to preserve it as a dedicated weight-like object.
|
||||||
|
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||||
|
|
||||||
|
/// Visits weight operands consumed by Pim core ops/core batches so downstream
|
||||||
|
/// passes can identify globals that must remain weight-backed.
|
||||||
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,546 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
||||||
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <filesystem>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
std::string getOutputDir() {
|
|
||||||
if (outputBaseName.empty() || outputBaseName == "-")
|
|
||||||
return {};
|
|
||||||
|
|
||||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
|
||||||
if (lastSlash == std::string::npos)
|
|
||||||
return ".";
|
|
||||||
return outputBaseName.substr(0, lastSlash);
|
|
||||||
}
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory) {
|
|
||||||
std::error_code errorCode;
|
|
||||||
std::filesystem::create_directories(directory, errorCode);
|
|
||||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
|
||||||
}
|
|
||||||
|
|
||||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
|
||||||
std::string outputDir = getOutputDir();
|
|
||||||
if (outputDir.empty())
|
|
||||||
return;
|
|
||||||
|
|
||||||
std::string dialectsDir = outputDir + "/dialects";
|
|
||||||
createDirectory(dialectsDir);
|
|
||||||
|
|
||||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
os << *moduleOp;
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
|
||||||
if (!moduleOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
|
||||||
if (entryPoints.size() > 1) {
|
|
||||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (!entryPoints.empty()) {
|
|
||||||
auto entryPointAttr =
|
|
||||||
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
|
||||||
if (!entryPointAttr) {
|
|
||||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
|
||||||
if (!entryFunc) {
|
|
||||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
|
||||||
<< entryPointAttr.getLeafReference().getValue();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return entryFunc;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
|
||||||
return mainGraphFunc;
|
|
||||||
|
|
||||||
SmallVector<func::FuncOp> nonExternalFuncs;
|
|
||||||
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
|
||||||
if (!funcOp.isExternal())
|
|
||||||
nonExternalFuncs.push_back(funcOp);
|
|
||||||
if (nonExternalFuncs.size() == 1)
|
|
||||||
return nonExternalFuncs.front();
|
|
||||||
|
|
||||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
|
||||||
|
|
||||||
void markWeightAlways(Operation* op) {
|
|
||||||
assert(op && "expected valid op");
|
|
||||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
|
||||||
}
|
|
||||||
|
|
||||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
|
||||||
if (!moduleOp || !getGlobalOp)
|
|
||||||
return {};
|
|
||||||
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
if (!channelNewOp) {
|
|
||||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
// channelNewOp should have two users: `op` and a
|
|
||||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
|
||||||
auto channelUsers = channelNewOp->getUsers();
|
|
||||||
auto usersIterator = channelUsers.begin();
|
|
||||||
auto firstUser = *usersIterator;
|
|
||||||
usersIterator++;
|
|
||||||
if (usersIterator == channelUsers.end()) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"only one found.");
|
|
||||||
channelNewOp->dump();
|
|
||||||
op->dump();
|
|
||||||
channelNewOp->getParentOp()->dump();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
auto secondUser = *usersIterator;
|
|
||||||
usersIterator++;
|
|
||||||
if (usersIterator != channelUsers.end()) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"more than two found.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
Operation* notOpUser;
|
|
||||||
if (firstUser == op) {
|
|
||||||
notOpUser = secondUser;
|
|
||||||
}
|
|
||||||
else if (secondUser == op) {
|
|
||||||
notOpUser = firstUser;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"and one of them must be me, but"
|
|
||||||
"none of them is actually me.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (opIsReceive) {
|
|
||||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
|
||||||
"me, the other is not a ChannelSendOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return notOpUser;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
|
||||||
"me, the other is not a ChannelReceiveOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return notOpUser;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
|
||||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
|
||||||
SmallVector<int64_t> indices(shape.size(), 0);
|
|
||||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
|
||||||
indices[dim] = linearIndex / stride;
|
|
||||||
linearIndex %= stride;
|
|
||||||
}
|
|
||||||
return indices;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
|
||||||
int64_t linearIndex = 0;
|
|
||||||
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
|
||||||
linearIndex += index * stride;
|
|
||||||
return linearIndex;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
|
||||||
int64_t numElements = 1;
|
|
||||||
for (int64_t dim : shape)
|
|
||||||
numElements *= dim;
|
|
||||||
return numElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
|
||||||
ArrayRef<int64_t> offsets,
|
|
||||||
ArrayRef<int64_t> sizes,
|
|
||||||
ArrayRef<int64_t> strides) {
|
|
||||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
|
||||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstNonZeroOffset = std::find_if(
|
|
||||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return offset != 0;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
|
||||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
|
||||||
if (size > dimension - offset)
|
|
||||||
return false;
|
|
||||||
++firstNonZeroOffset;
|
|
||||||
|
|
||||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstDifferentSize != sizesAndShape.end()) {
|
|
||||||
++firstDifferentSize;
|
|
||||||
|
|
||||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, _dimension] = sizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
if (!knowledge)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
auto iter = knowledge->aliases.find(value);
|
|
||||||
while (iter != knowledge->aliases.end()) {
|
|
||||||
value = iter->second;
|
|
||||||
iter = knowledge->aliases.find(value);
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
|
|
||||||
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
|
|
||||||
// and when propagating yielded values across iterations during static unrolling.
|
|
||||||
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
|
||||||
return value;
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
if (auto result = dyn_cast<OpResult>(value))
|
|
||||||
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
|
||||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
|
||||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
|
||||||
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
|
||||||
|
|
||||||
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
if (knowledge) {
|
|
||||||
auto iter = knowledge->indexValues.find(value);
|
|
||||||
if (iter != knowledge->indexValues.end())
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
|
||||||
if (constantOp) {
|
|
||||||
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
|
|
||||||
return integerAttr.getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
|
|
||||||
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
|
||||||
|
|
||||||
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return failure();
|
|
||||||
return *lhs + *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return failure();
|
|
||||||
return *lhs - *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return failure();
|
|
||||||
return *lhs * *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
|
||||||
return failure();
|
|
||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
|
||||||
return failure();
|
|
||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
|
||||||
if (auto attr = dyn_cast<Attribute>(ofr)) {
|
|
||||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
|
||||||
if (!integerAttr)
|
|
||||||
return failure();
|
|
||||||
return integerAttr.getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
|
|
||||||
const StaticValueKnowledge* knowledge) {
|
|
||||||
int64_t byteOffset = 0;
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (isa<BlockArgument>(value))
|
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
|
||||||
if (!tiedOperand)
|
|
||||||
return failure();
|
|
||||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
|
|
||||||
auto result = dyn_cast<OpResult>(value);
|
|
||||||
if (!result)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Trace the loop carry back to its underlying memref, then if that memref is the
|
|
||||||
// loop's own iter-arg we know the base comes from the corresponding init arg
|
|
||||||
// (every iteration yields the same backing memory in the DPS sense).
|
|
||||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
||||||
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
|
||||||
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
|
|
||||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
|
||||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
|
||||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
value = yieldedValue;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
|
||||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
|
||||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> offsets;
|
|
||||||
SmallVector<int64_t> sizes;
|
|
||||||
SmallVector<int64_t> strides;
|
|
||||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
|
||||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
|
||||||
strides.reserve(subviewOp.getMixedStrides().size());
|
|
||||||
|
|
||||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
|
||||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
|
||||||
if (failed(resolvedOffset))
|
|
||||||
return failure();
|
|
||||||
offsets.push_back(*resolvedOffset);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
|
||||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
|
||||||
if (failed(resolvedSize))
|
|
||||||
return failure();
|
|
||||||
sizes.push_back(*resolvedSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
|
||||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
|
||||||
if (failed(resolvedStride))
|
|
||||||
return failure();
|
|
||||||
strides.push_back(*resolvedStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
|
||||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
|
||||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
|
||||||
value = resolveAlias(castOp.getSource(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
|
||||||
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
|
|
||||||
|
|
||||||
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveIndexValueImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
|
||||||
return resolveContiguousAddressImpl(value, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveContiguousAddressImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isCoreStaticAddressOp(Operation* op) {
|
|
||||||
return isa<arith::ConstantOp,
|
|
||||||
arith::AddIOp,
|
|
||||||
arith::SubIOp,
|
|
||||||
arith::MulIOp,
|
|
||||||
arith::DivUIOp,
|
|
||||||
arith::RemUIOp,
|
|
||||||
arith::IndexCastOp,
|
|
||||||
memref::AllocOp,
|
|
||||||
memref::SubViewOp,
|
|
||||||
memref::CastOp,
|
|
||||||
memref::CollapseShapeOp,
|
|
||||||
memref::ExpandShapeOp>(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult walkPimCoreBlock(Block& block,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
|
|
||||||
bool hasFailure = false;
|
|
||||||
for (Operation& op : block) {
|
|
||||||
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
|
||||||
Block& loopBody = forOp.getRegion().front();
|
|
||||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
|
||||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
|
||||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
|
||||||
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
|
||||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
|
||||||
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
|
||||||
StaticValueKnowledge loopKnowledge = knowledge;
|
|
||||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
|
||||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
|
||||||
loopKnowledge.aliases[iterArg] = iterValue;
|
|
||||||
|
|
||||||
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
|
||||||
hasFailure = true;
|
|
||||||
|
|
||||||
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
|
|
||||||
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
|
||||||
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (failed(callback(op, knowledge)))
|
|
||||||
hasFailure = true;
|
|
||||||
}
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -7,82 +7,23 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
struct ResolvedContiguousAddress {
|
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
|
||||||
mlir::Value base;
|
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
|
||||||
int64_t byteOffset = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct StaticValueKnowledge {
|
|
||||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
|
||||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
|
||||||
|
|
||||||
StaticValueKnowledge() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string getOutputDir();
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory);
|
|
||||||
|
|
||||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
|
||||||
|
|
||||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
|
||||||
|
|
||||||
bool hasWeightAlways(mlir::Operation* op);
|
|
||||||
|
|
||||||
void markWeightAlways(mlir::Operation* op);
|
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Operation*>
|
|
||||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t>
|
|
||||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
|
||||||
llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
|
||||||
const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
|
|
||||||
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
|
|
||||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
|
|
||||||
/// only contribute to static addressing or index computations (arith integer math,
|
|
||||||
/// memref view ops, memref.alloc, arith.constant).
|
|
||||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
|
||||||
|
|
||||||
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
|
|
||||||
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
|
|
||||||
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
|
|
||||||
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
|
|
||||||
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
|
|
||||||
mlir::LogicalResult
|
|
||||||
walkPimCoreBlock(mlir::Block& block,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string dialectsDir = outputDir + "/dialects";
|
||||||
|
createDirectory(dialectsDir);
|
||||||
|
|
||||||
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
|
llvm::raw_os_ostream os(file);
|
||||||
|
mlir::OpPrintingFlags flags;
|
||||||
|
flags.elideLargeElementsAttrs();
|
||||||
|
moduleOp.print(os, flags);
|
||||||
|
os.flush();
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Emits a MLIR snapshot under the current compiler output
|
||||||
|
/// directory for pass-level debugging.
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
|
||||||
|
return op->emitOpError() << "requires statically shaped " << valueDescription;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||||
|
llvm::StringRef valueDescription,
|
||||||
|
int64_t actualRank,
|
||||||
|
llvm::ArrayRef<int64_t> supportedRanks) {
|
||||||
|
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
|
||||||
|
if (supportedRanks.empty())
|
||||||
|
return diag;
|
||||||
|
|
||||||
|
diag << "; supported rank";
|
||||||
|
if (supportedRanks.size() != 1)
|
||||||
|
diag << 's';
|
||||||
|
diag << ' ';
|
||||||
|
|
||||||
|
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
|
||||||
|
return diag;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::InFlightDiagnostic
|
||||||
|
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
|
||||||
|
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
|
||||||
|
llvm::StringRef action,
|
||||||
|
llvm::StringRef path,
|
||||||
|
const std::error_code& errorCode) {
|
||||||
|
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <system_error>
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
struct CappedDiagnosticReporter {
|
||||||
|
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||||
|
|
||||||
|
template <typename EmitFn>
|
||||||
|
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||||
|
numFailures++;
|
||||||
|
if (numFailures <= maxReportedFailures)
|
||||||
|
emit(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||||
|
if (numFailures > maxReportedFailures)
|
||||||
|
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||||
|
<< failureDescription;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasFailure() const { return numFailures != 0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t maxReportedFailures;
|
||||||
|
int64_t numFailures = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||||
|
|
||||||
|
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
|
||||||
|
/// accepted by the current lowering/codegen path.
|
||||||
|
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||||
|
llvm::StringRef valueDescription,
|
||||||
|
int64_t actualRank,
|
||||||
|
llvm::ArrayRef<int64_t> supportedRanks);
|
||||||
|
|
||||||
|
/// Emits a consistent diagnostic for missing symbol/global references.
|
||||||
|
mlir::InFlightDiagnostic
|
||||||
|
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
|
||||||
|
|
||||||
|
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
|
||||||
|
/// the relevant IR location.
|
||||||
|
mlir::LogicalResult
|
||||||
|
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
|
||||||
|
return mlir::success(succeeded(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
#include <filesystem>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::string getOutputDir() {
|
||||||
|
if (outputBaseName.empty() || outputBaseName == "-")
|
||||||
|
return {};
|
||||||
|
|
||||||
|
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||||
|
if (lastSlash == std::string::npos)
|
||||||
|
return ".";
|
||||||
|
return outputBaseName.substr(0, lastSlash);
|
||||||
|
}
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory) {
|
||||||
|
std::error_code errorCode;
|
||||||
|
std::filesystem::create_directories(directory, errorCode);
|
||||||
|
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns the directory that should hold PIM artifacts/debug dumps for the
|
||||||
|
/// current compiler invocation.
|
||||||
|
std::string getOutputDir();
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
#include "llvm/Support/Format.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::fstream openReportFile(const std::string& name) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
std::string reportsDir = outputDir + "/reports";
|
||||||
|
createDirectory(reportsDir);
|
||||||
|
return std::fstream(reportsDir + "/" + name + ".txt", std::ios::out);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string formatReportMemory(uint64_t bytes) {
|
||||||
|
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||||
|
int i = 0;
|
||||||
|
double size = static_cast<double>(bytes);
|
||||||
|
while (size >= 1024 && i < 6) {
|
||||||
|
size /= 1024;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string out;
|
||||||
|
llvm::raw_string_ostream rss(out);
|
||||||
|
rss << llvm::format("%.2f ", size) << units[i];
|
||||||
|
return rss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
|
||||||
|
for (const ReportField& field : fields)
|
||||||
|
os << "\t" << field.label << ": " << field.value << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields) {
|
||||||
|
os << "\t" << title << ":\n";
|
||||||
|
for (const ReportField& field : fields)
|
||||||
|
os << "\t " << field.label << ": " << field.value << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
|
||||||
|
os << "Totals:\n";
|
||||||
|
for (const ReportField& field : fields)
|
||||||
|
os << "\t" << field.label << ": " << field.value << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
|
||||||
|
llvm::ArrayRef<ReportField> perCoreFields,
|
||||||
|
llvm::ArrayRef<ReportField> totalFields) {
|
||||||
|
printReportFieldBlock(os, "Per core", perCoreFields);
|
||||||
|
printReportFieldBlock(os, "Total", totalFields);
|
||||||
|
}
|
||||||
|
|
||||||
|
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry) {
|
||||||
|
if (hasNextEntry)
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <limits>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::fstream openReportFile(const std::string& name);
|
||||||
|
std::string formatReportMemory(uint64_t bytes);
|
||||||
|
|
||||||
|
struct ReportField {
|
||||||
|
std::string label;
|
||||||
|
std::string value;
|
||||||
|
};
|
||||||
|
|
||||||
|
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
|
||||||
|
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields);
|
||||||
|
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
|
||||||
|
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
|
||||||
|
llvm::ArrayRef<ReportField> perCoreFields,
|
||||||
|
llvm::ArrayRef<ReportField> totalFields);
|
||||||
|
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry);
|
||||||
|
|
||||||
|
template <typename EntryTy>
|
||||||
|
int32_t getFirstReportCoreId(const EntryTy& entry) {
|
||||||
|
if (entry.coreIds.empty())
|
||||||
|
return std::numeric_limits<int32_t>::max();
|
||||||
|
return entry.coreIds.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename EntryRange>
|
||||||
|
void sortReportEntriesByFirstCore(EntryRange& entries) {
|
||||||
|
llvm::stable_sort(entries, [](const auto& lhs, const auto& rhs) {
|
||||||
|
int32_t lhsFirstCore = getFirstReportCoreId(lhs);
|
||||||
|
int32_t rhsFirstCore = getFirstReportCoreId(rhs);
|
||||||
|
if (lhsFirstCore != rhsFirstCore)
|
||||||
|
return lhsFirstCore < rhsFirstCore;
|
||||||
|
return lhs.id < rhs.id;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
|
|||||||
|
|
||||||
add_pim_library(OMPimCompilerUtils
|
add_pim_library(OMPimCompilerUtils
|
||||||
PimCompilerUtils.cpp
|
PimCompilerUtils.cpp
|
||||||
|
PimArtifactWriter.cpp
|
||||||
|
PimBatchEmission.cpp
|
||||||
PimCodeGen.cpp
|
PimCodeGen.cpp
|
||||||
|
PimWeightEmitter.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
@@ -26,6 +29,7 @@ add_pim_library(OMPimCompilerUtils
|
|||||||
OMPimCompilerOptions
|
OMPimCompilerOptions
|
||||||
OMPimCommon
|
OMPimCommon
|
||||||
OMPimBufferization
|
OMPimBufferization
|
||||||
|
OMPimStaticMemoryCoalescing
|
||||||
OMPimPasses
|
OMPimPasses
|
||||||
OMONNXToSpatial
|
OMONNXToSpatial
|
||||||
OMSpatialToPim
|
OMSpatialToPim
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/Support/FileSystem.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes
|
||||||
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
|
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||||
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (hasWeightAlways(getGlobalOp))
|
||||||
|
return;
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp)
|
||||||
|
return;
|
||||||
|
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||||
|
return;
|
||||||
|
auto initialValue = globalOp.getInitialValue();
|
||||||
|
if (!initialValue)
|
||||||
|
return;
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||||
|
if (!denseAttr)
|
||||||
|
return;
|
||||||
|
|
||||||
|
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
|
||||||
|
ArrayRef<char> rawData = denseAttr.getRawData();
|
||||||
|
char* dst = memoryBuffer.data() + memEntry.address;
|
||||||
|
|
||||||
|
if (denseAttr.isSplat()) {
|
||||||
|
size_t elementSize = rawData.size();
|
||||||
|
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
|
||||||
|
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
|
||||||
|
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
assert(rawData.size() == memEntry.size && "Data size mismatch");
|
||||||
|
std::memcpy(dst, rawData.data(), rawData.size());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
|
||||||
|
memoryFileStream.close();
|
||||||
|
return CompilerSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||||
|
PimAcceleratorMemory& memory,
|
||||||
|
size_t maxCoreId,
|
||||||
|
json::Object xbarsPerArrayGroup,
|
||||||
|
StringRef outputDirPath) {
|
||||||
|
json::Object configJson;
|
||||||
|
|
||||||
|
configJson["core_cnt"] = maxCoreId + 1;
|
||||||
|
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||||
|
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||||
|
|
||||||
|
json::Array inputsAddresses;
|
||||||
|
for (BlockArgument input : funcOp.getArguments())
|
||||||
|
inputsAddresses.push_back(memory.getValueAddress(input));
|
||||||
|
configJson["inputs_addresses"] = std::move(inputsAddresses);
|
||||||
|
|
||||||
|
json::Array outputsAddresses;
|
||||||
|
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||||
|
for (mlir::Value output : returnOp.getOperands())
|
||||||
|
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||||
|
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||||
|
|
||||||
|
auto configPath = (outputDirPath + "/config.json").str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream jsonOS(configPath, errorCode);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening config file: " << errorCode.message() << '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
jsonOS << json::Value(std::move(configJson)) << '\n';
|
||||||
|
jsonOS.close();
|
||||||
|
return CompilerSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
|
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
class PimAcceleratorMemory;
|
||||||
|
|
||||||
|
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||||
|
mlir::func::FuncOp funcOp,
|
||||||
|
PimAcceleratorMemory& memory,
|
||||||
|
llvm::StringRef outputDirPath);
|
||||||
|
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
|
||||||
|
PimAcceleratorMemory& memory,
|
||||||
|
size_t maxCoreId,
|
||||||
|
llvm::json::Object xbarsPerArrayGroup,
|
||||||
|
llvm::StringRef outputDirPath);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||||
|
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||||
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||||
|
SmallVector<int32_t> laneCoreIds;
|
||||||
|
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||||
|
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||||
|
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||||
|
return laneCoreIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||||
|
pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
OperationFolder& constantFolder) {
|
||||||
|
Block& oldBlock = coreBatchOp.getBody().front();
|
||||||
|
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||||
|
size_t weightCount = coreBatchOp.getWeights().size();
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||||
|
if (blockArg.getType().isIndex()) {
|
||||||
|
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argIndex <= weightCount) {
|
||||||
|
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t inputIndex = argIndex - 1 - weightCount;
|
||||||
|
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
|
||||||
|
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation& op : oldBlock) {
|
||||||
|
if (isa<pim::PimHaltOp>(op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||||
|
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||||
|
pim::PimSendOp::create(
|
||||||
|
builder,
|
||||||
|
sendBatchOp.getLoc(),
|
||||||
|
mapper.lookup(sendBatchOp.getInput()),
|
||||||
|
sendBatchOp.getSizeAttr(),
|
||||||
|
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||||
|
pim::PimSendTensorOp::create(
|
||||||
|
builder,
|
||||||
|
sendTensorBatchOp.getLoc(),
|
||||||
|
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||||
|
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||||
|
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||||
|
auto scalarReceive = pim::PimReceiveOp::create(
|
||||||
|
builder,
|
||||||
|
receiveBatchOp.getLoc(),
|
||||||
|
receiveBatchOp.getOutput().getType(),
|
||||||
|
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||||
|
receiveBatchOp.getSizeAttr(),
|
||||||
|
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
|
||||||
|
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||||
|
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||||
|
builder,
|
||||||
|
receiveTensorBatchOp.getLoc(),
|
||||||
|
receiveTensorBatchOp.getOutput().getType(),
|
||||||
|
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||||
|
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||||
|
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||||
|
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
|
||||||
|
builder,
|
||||||
|
memcpBatchOp.getLoc(),
|
||||||
|
memcpBatchOp.getOutput().getType(),
|
||||||
|
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||||
|
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||||
|
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||||
|
mapper.lookup(memcpBatchOp.getHostSource()),
|
||||||
|
memcpBatchOp.getSizeAttr());
|
||||||
|
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = builder.clone(op, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
ArrayRef<unsigned> lanes,
|
||||||
|
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||||
|
assert(!lanes.empty() && "expected at least one batch lane");
|
||||||
|
|
||||||
|
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||||
|
OpBuilder builder(scratchModule->getContext());
|
||||||
|
OperationFolder constantFolder(scratchModule->getContext());
|
||||||
|
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||||
|
|
||||||
|
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
|
||||||
|
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
int32_t coreId = coreIds[lanes.front()];
|
||||||
|
for (unsigned lane : lanes)
|
||||||
|
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
|
||||||
|
|
||||||
|
auto scalarCore =
|
||||||
|
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
|
||||||
|
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||||
|
builder.setInsertionPointToEnd(block);
|
||||||
|
for (unsigned lane : lanes)
|
||||||
|
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
|
||||||
|
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||||
|
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||||
|
return callback(scalarCore);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||||
|
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||||
|
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
llvm::ArrayRef<unsigned> lanes,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,374 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "llvm/Support/Endian.h"
|
||||||
|
#include "llvm/Support/JSON.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim_binary {
|
||||||
|
|
||||||
|
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};
|
||||||
|
inline constexpr uint32_t kVersion = 1;
|
||||||
|
inline constexpr uint64_t kCountOffset = 8;
|
||||||
|
inline constexpr size_t kHeaderSize = 12;
|
||||||
|
inline constexpr size_t kRecordSize = 20;
|
||||||
|
|
||||||
|
enum class Opcode : uint32_t {
|
||||||
|
nop = 0,
|
||||||
|
sldi = 1,
|
||||||
|
sld = 2,
|
||||||
|
sadd = 3,
|
||||||
|
ssub = 4,
|
||||||
|
smul = 5,
|
||||||
|
saddi = 6,
|
||||||
|
smuli = 7,
|
||||||
|
setbw = 8,
|
||||||
|
mvmul = 9,
|
||||||
|
vvadd = 10,
|
||||||
|
vvsub = 11,
|
||||||
|
vvmul = 12,
|
||||||
|
vvdmul = 13,
|
||||||
|
vvmax = 14,
|
||||||
|
vvsll = 15,
|
||||||
|
vvsra = 16,
|
||||||
|
vavg = 17,
|
||||||
|
vrelu = 18,
|
||||||
|
vtanh = 19,
|
||||||
|
vsigm = 20,
|
||||||
|
vsoftmax = 21,
|
||||||
|
vmv = 22,
|
||||||
|
vrsu = 23,
|
||||||
|
vrsl = 24,
|
||||||
|
ld = 25,
|
||||||
|
st = 26,
|
||||||
|
lldi = 27,
|
||||||
|
lmv = 28,
|
||||||
|
send = 29,
|
||||||
|
recv = 30,
|
||||||
|
wait = 31,
|
||||||
|
sync = 32,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct InstructionRecord {
|
||||||
|
Opcode opcode = Opcode::nop;
|
||||||
|
uint8_t rd = 0;
|
||||||
|
uint8_t r1 = 0;
|
||||||
|
int32_t r2OrImm = 0;
|
||||||
|
int32_t generic1 = 0;
|
||||||
|
int32_t generic2 = 0;
|
||||||
|
int32_t generic3 = 0;
|
||||||
|
uint8_t flags = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
|
||||||
|
std::array<char, sizeof(uint32_t)> bytes;
|
||||||
|
llvm::support::endian::write32le(bytes.data(), value);
|
||||||
|
os.write(bytes.data(), bytes.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
|
||||||
|
|
||||||
|
inline void writeHeader(llvm::raw_ostream& os) {
|
||||||
|
os.write(kMagic, sizeof(kMagic));
|
||||||
|
writeUint32LE(os, kVersion);
|
||||||
|
writeUint32LE(os, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
|
||||||
|
std::array<char, sizeof(uint32_t)> bytes;
|
||||||
|
llvm::support::endian::write32le(bytes.data(), instructionCount);
|
||||||
|
os.pwrite(bytes.data(), bytes.size(), kCountOffset);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
|
||||||
|
os << static_cast<char>(static_cast<uint8_t>(record.opcode));
|
||||||
|
os << static_cast<char>(record.rd);
|
||||||
|
os << static_cast<char>(record.r1);
|
||||||
|
os << static_cast<char>(record.flags);
|
||||||
|
writeInt32LE(os, record.r2OrImm);
|
||||||
|
writeInt32LE(os, record.generic1);
|
||||||
|
writeInt32LE(os, record.generic2);
|
||||||
|
writeInt32LE(os, record.generic3);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int32_t toI32(int64_t value) {
|
||||||
|
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
|
||||||
|
&& "PIM binary field out of int32 range");
|
||||||
|
return static_cast<int32_t>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint8_t toU8(int64_t value) {
|
||||||
|
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range");
|
||||||
|
return static_cast<uint8_t>(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
|
||||||
|
if (std::optional<int64_t> value = object.getInteger(key))
|
||||||
|
return toI32(*value);
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Opcode opcodeFromString(llvm::StringRef opName) {
|
||||||
|
if (opName == "nop")
|
||||||
|
return Opcode::nop;
|
||||||
|
if (opName == "sldi")
|
||||||
|
return Opcode::sldi;
|
||||||
|
if (opName == "sld")
|
||||||
|
return Opcode::sld;
|
||||||
|
if (opName == "sadd")
|
||||||
|
return Opcode::sadd;
|
||||||
|
if (opName == "ssub")
|
||||||
|
return Opcode::ssub;
|
||||||
|
if (opName == "smul")
|
||||||
|
return Opcode::smul;
|
||||||
|
if (opName == "saddi")
|
||||||
|
return Opcode::saddi;
|
||||||
|
if (opName == "smuli")
|
||||||
|
return Opcode::smuli;
|
||||||
|
if (opName == "setbw")
|
||||||
|
return Opcode::setbw;
|
||||||
|
if (opName == "mvmul")
|
||||||
|
return Opcode::mvmul;
|
||||||
|
if (opName == "vvadd")
|
||||||
|
return Opcode::vvadd;
|
||||||
|
if (opName == "vvsub")
|
||||||
|
return Opcode::vvsub;
|
||||||
|
if (opName == "vvmul")
|
||||||
|
return Opcode::vvmul;
|
||||||
|
if (opName == "vvdmul")
|
||||||
|
return Opcode::vvdmul;
|
||||||
|
if (opName == "vvmax")
|
||||||
|
return Opcode::vvmax;
|
||||||
|
if (opName == "vvsll")
|
||||||
|
return Opcode::vvsll;
|
||||||
|
if (opName == "vvsra")
|
||||||
|
return Opcode::vvsra;
|
||||||
|
if (opName == "vavg")
|
||||||
|
return Opcode::vavg;
|
||||||
|
if (opName == "vrelu")
|
||||||
|
return Opcode::vrelu;
|
||||||
|
if (opName == "vtanh")
|
||||||
|
return Opcode::vtanh;
|
||||||
|
if (opName == "vsigm")
|
||||||
|
return Opcode::vsigm;
|
||||||
|
if (opName == "vsoftmax")
|
||||||
|
return Opcode::vsoftmax;
|
||||||
|
if (opName == "vmv")
|
||||||
|
return Opcode::vmv;
|
||||||
|
if (opName == "vrsu")
|
||||||
|
return Opcode::vrsu;
|
||||||
|
if (opName == "vrsl")
|
||||||
|
return Opcode::vrsl;
|
||||||
|
if (opName == "ld")
|
||||||
|
return Opcode::ld;
|
||||||
|
if (opName == "st")
|
||||||
|
return Opcode::st;
|
||||||
|
if (opName == "lldi")
|
||||||
|
return Opcode::lldi;
|
||||||
|
if (opName == "lmv")
|
||||||
|
return Opcode::lmv;
|
||||||
|
if (opName == "send")
|
||||||
|
return Opcode::send;
|
||||||
|
if (opName == "recv")
|
||||||
|
return Opcode::recv;
|
||||||
|
if (opName == "wait")
|
||||||
|
return Opcode::wait;
|
||||||
|
if (opName == "sync")
|
||||||
|
return Opcode::sync;
|
||||||
|
llvm_unreachable("Unsupported PIM binary opcode");
|
||||||
|
}
|
||||||
|
|
||||||
|
inline llvm::StringRef opcodeToString(Opcode opcode) {
|
||||||
|
switch (opcode) {
|
||||||
|
case Opcode::nop: return "nop";
|
||||||
|
case Opcode::sldi: return "sldi";
|
||||||
|
case Opcode::sld: return "sld";
|
||||||
|
case Opcode::sadd: return "sadd";
|
||||||
|
case Opcode::ssub: return "ssub";
|
||||||
|
case Opcode::smul: return "smul";
|
||||||
|
case Opcode::saddi: return "saddi";
|
||||||
|
case Opcode::smuli: return "smuli";
|
||||||
|
case Opcode::setbw: return "setbw";
|
||||||
|
case Opcode::mvmul: return "mvmul";
|
||||||
|
case Opcode::vvadd: return "vvadd";
|
||||||
|
case Opcode::vvsub: return "vvsub";
|
||||||
|
case Opcode::vvmul: return "vvmul";
|
||||||
|
case Opcode::vvdmul: return "vvdmul";
|
||||||
|
case Opcode::vvmax: return "vvmax";
|
||||||
|
case Opcode::vvsll: return "vvsll";
|
||||||
|
case Opcode::vvsra: return "vvsra";
|
||||||
|
case Opcode::vavg: return "vavg";
|
||||||
|
case Opcode::vrelu: return "vrelu";
|
||||||
|
case Opcode::vtanh: return "vtanh";
|
||||||
|
case Opcode::vsigm: return "vsigm";
|
||||||
|
case Opcode::vsoftmax: return "vsoftmax";
|
||||||
|
case Opcode::vmv: return "vmv";
|
||||||
|
case Opcode::vrsu: return "vrsu";
|
||||||
|
case Opcode::vrsl: return "vrsl";
|
||||||
|
case Opcode::ld: return "ld";
|
||||||
|
case Opcode::st: return "st";
|
||||||
|
case Opcode::lldi: return "lldi";
|
||||||
|
case Opcode::lmv: return "lmv";
|
||||||
|
case Opcode::send: return "send";
|
||||||
|
case Opcode::recv: return "recv";
|
||||||
|
case Opcode::wait: return "wait";
|
||||||
|
case Opcode::sync: return "sync";
|
||||||
|
}
|
||||||
|
llvm_unreachable("Unsupported PIM binary opcode");
|
||||||
|
}
|
||||||
|
|
||||||
|
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
|
||||||
|
InstructionRecord record;
|
||||||
|
std::optional<llvm::StringRef> opName = instruction.getString("op");
|
||||||
|
assert(opName && "Missing op field in PIM instruction");
|
||||||
|
record.opcode = opcodeFromString(*opName);
|
||||||
|
record.rd = toU8(getOptionalInt(instruction, "rd"));
|
||||||
|
record.r1 = toU8(getOptionalInt(instruction, "rs1"));
|
||||||
|
|
||||||
|
switch (record.opcode) {
|
||||||
|
case Opcode::sldi:
|
||||||
|
case Opcode::saddi:
|
||||||
|
case Opcode::smuli:
|
||||||
|
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
|
||||||
|
case Opcode::mvmul:
|
||||||
|
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
||||||
|
record.generic1 = getOptionalInt(instruction, "relu");
|
||||||
|
record.generic2 = getOptionalInt(instruction, "group");
|
||||||
|
break;
|
||||||
|
case Opcode::setbw:
|
||||||
|
record.generic1 = getOptionalInt(instruction, "ibiw");
|
||||||
|
record.generic2 = getOptionalInt(instruction, "obiw");
|
||||||
|
break;
|
||||||
|
case Opcode::send:
|
||||||
|
case Opcode::recv:
|
||||||
|
record.r2OrImm = getOptionalInt(instruction, "core");
|
||||||
|
record.generic3 = getOptionalInt(instruction, "size");
|
||||||
|
break;
|
||||||
|
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
||||||
|
if (auto* offsetValue = instruction.getObject("offset")) {
|
||||||
|
record.generic1 = getOptionalInt(*offsetValue, "offset_select");
|
||||||
|
record.generic2 = getOptionalInt(*offsetValue, "offset_value");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (instruction.get("len"))
|
||||||
|
record.generic3 = getOptionalInt(instruction, "len");
|
||||||
|
else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv)
|
||||||
|
record.generic3 = getOptionalInt(instruction, "size");
|
||||||
|
|
||||||
|
return record;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
|
||||||
|
llvm::json::Object instruction;
|
||||||
|
instruction["op"] = opcodeToString(record.opcode).str();
|
||||||
|
|
||||||
|
auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) {
|
||||||
|
llvm::json::Object offset;
|
||||||
|
offset["offset_select"] = offsetSelect;
|
||||||
|
offset["offset_value"] = offsetValue;
|
||||||
|
instruction["offset"] = std::move(offset);
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (record.opcode) {
|
||||||
|
case Opcode::sldi:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["imm"] = record.r2OrImm;
|
||||||
|
break;
|
||||||
|
case Opcode::sld:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
break;
|
||||||
|
case Opcode::sadd:
|
||||||
|
case Opcode::ssub:
|
||||||
|
case Opcode::smul:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
instruction["rs2"] = record.r2OrImm;
|
||||||
|
break;
|
||||||
|
case Opcode::saddi:
|
||||||
|
case Opcode::smuli:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
instruction["imm"] = record.r2OrImm;
|
||||||
|
break;
|
||||||
|
case Opcode::setbw:
|
||||||
|
instruction["ibiw"] = record.generic1;
|
||||||
|
instruction["obiw"] = record.generic2;
|
||||||
|
break;
|
||||||
|
case Opcode::mvmul:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
instruction["mbiw"] = record.r2OrImm;
|
||||||
|
instruction["relu"] = record.generic1;
|
||||||
|
instruction["group"] = record.generic2;
|
||||||
|
break;
|
||||||
|
case Opcode::vvadd:
|
||||||
|
case Opcode::vvsub:
|
||||||
|
case Opcode::vvmul:
|
||||||
|
case Opcode::vvdmul:
|
||||||
|
case Opcode::vvmax:
|
||||||
|
case Opcode::vvsll:
|
||||||
|
case Opcode::vvsra:
|
||||||
|
case Opcode::vavg:
|
||||||
|
case Opcode::vmv:
|
||||||
|
case Opcode::vrsu:
|
||||||
|
case Opcode::vrsl:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
instruction["rs2"] = record.r2OrImm;
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
instruction["len"] = record.generic3;
|
||||||
|
break;
|
||||||
|
case Opcode::vrelu:
|
||||||
|
case Opcode::vtanh:
|
||||||
|
case Opcode::vsigm:
|
||||||
|
case Opcode::vsoftmax:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
instruction["len"] = record.generic3;
|
||||||
|
break;
|
||||||
|
case Opcode::ld:
|
||||||
|
case Opcode::st:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
instruction["size"] = record.generic3;
|
||||||
|
break;
|
||||||
|
case Opcode::lldi:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["imm"] = record.r2OrImm;
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
instruction["len"] = record.generic3;
|
||||||
|
break;
|
||||||
|
case Opcode::lmv:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
instruction["len"] = record.generic3;
|
||||||
|
break;
|
||||||
|
case Opcode::send:
|
||||||
|
case Opcode::recv:
|
||||||
|
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||||
|
instruction["core"] = record.r2OrImm;
|
||||||
|
addOffset(record.generic1, record.generic2);
|
||||||
|
instruction["size"] = record.generic3;
|
||||||
|
break;
|
||||||
|
case Opcode::wait:
|
||||||
|
case Opcode::sync:
|
||||||
|
case Opcode::nop: break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return instruction;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim_binary
|
||||||
+642
-436
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,19 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -14,16 +23,42 @@ struct MemEntry {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MemoryReportRow {
|
||||||
|
uint64_t numAlloca = 0;
|
||||||
|
uint64_t sizeAlloca = 0;
|
||||||
|
uint64_t numGlobal = 0;
|
||||||
|
uint64_t sizeGlobal = 0;
|
||||||
|
|
||||||
|
bool operator==(const MemoryReportRow& other) const {
|
||||||
|
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
|
||||||
|
&& sizeGlobal == other.sizeGlobal;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MemoryReportEntry {
|
||||||
|
enum class Kind {
|
||||||
|
Core,
|
||||||
|
Batch
|
||||||
|
};
|
||||||
|
|
||||||
|
Kind kind = Kind::Core;
|
||||||
|
uint64_t id = 0;
|
||||||
|
llvm::SmallVector<int32_t, 8> coreIds;
|
||||||
|
MemoryReportRow row;
|
||||||
|
uint64_t totalAllocaCount = 0;
|
||||||
|
uint64_t totalAllocaBytes = 0;
|
||||||
|
};
|
||||||
|
|
||||||
class PimMemory {
|
class PimMemory {
|
||||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||||
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
|
||||||
|
|
||||||
size_t maxSize = 0; // 0 for unbounded memory
|
|
||||||
size_t startAddress = 0;
|
|
||||||
size_t minAlignment = 4;
|
size_t minAlignment = 4;
|
||||||
size_t firstAvailableAddress = 0;
|
size_t firstAvailableAddress = 0;
|
||||||
|
|
||||||
MemEntry* gatherMemEntry(mlir::Value value);
|
MemEntry* gatherMemEntry(mlir::Value value);
|
||||||
|
void allocateGatheredMemory();
|
||||||
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@@ -32,6 +67,8 @@ public:
|
|||||||
|
|
||||||
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||||
void allocateCore(mlir::Operation* op);
|
void allocateCore(mlir::Operation* op);
|
||||||
|
MemoryReportRow getReportRow() const;
|
||||||
|
void remove(mlir::Value val);
|
||||||
|
|
||||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||||
MemEntry getMemEntry(mlir::Value value) const;
|
MemEntry getMemEntry(mlir::Value value) const;
|
||||||
@@ -44,26 +81,41 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||||
|
std::fstream fileReport;
|
||||||
|
std::optional<MemoryReportRow> hostReportRow;
|
||||||
|
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimAcceleratorMemory()
|
PimAcceleratorMemory()
|
||||||
: hostMem(memEntriesMap) {}
|
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
|
||||||
|
|
||||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||||
|
|
||||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||||
|
void reportHost();
|
||||||
|
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||||
|
void recordBatchReport(uint64_t batchId,
|
||||||
|
llvm::ArrayRef<int32_t> coreIds,
|
||||||
|
const MemoryReportRow& perCoreRow,
|
||||||
|
uint64_t totalAllocaCount,
|
||||||
|
uint64_t totalAllocaBytes);
|
||||||
|
void flushReport();
|
||||||
|
void clean(mlir::Operation* op);
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimCodeGen {
|
class PimCodeGen {
|
||||||
PimAcceleratorMemory& memory;
|
PimAcceleratorMemory& memory;
|
||||||
llvm::raw_fd_ostream& coreFileStream;
|
llvm::raw_fd_ostream& coreBinaryStream;
|
||||||
|
llvm::raw_fd_ostream* coreJsonStream;
|
||||||
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||||
|
mutable uint32_t emittedInstructionCount = 0;
|
||||||
|
|
||||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||||
return memory.getValueAddress(value, knowledge);
|
return memory.getValueAddress(value, knowledge);
|
||||||
}
|
}
|
||||||
|
size_t remapCoreId(size_t coreId) const;
|
||||||
|
|
||||||
static llvm::json::Object createEmptyOffset();
|
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||||
void emitInstruction(llvm::json::Object instruction) const;
|
|
||||||
|
|
||||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||||
@@ -82,15 +134,23 @@ class PimCodeGen {
|
|||||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
PimCodeGen(PimAcceleratorMemory& memory,
|
||||||
: memory(memory), coreFileStream(coreJson) {}
|
llvm::raw_fd_ostream& coreBinary,
|
||||||
|
llvm::raw_fd_ostream* coreJson,
|
||||||
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
||||||
|
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||||
|
|
||||||
|
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
|
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
|
||||||
@@ -105,9 +165,10 @@ public:
|
|||||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,18 +1,7 @@
|
|||||||
/*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
*/
|
|
||||||
|
|
||||||
//===------------------------- PimCompilerOptions.cpp --------------------===//
|
|
||||||
//
|
|
||||||
// Copyright 2022 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
// Compiler Options for PIM
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerOptions"
|
#define DEBUG_TYPE "PimCompilerOptions"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -26,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
|||||||
llvm::cl::init(EmitPimCodegen),
|
llvm::cl::init(EmitPimCodegen),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||||
|
"pim-merge-scheduler",
|
||||||
|
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||||
|
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||||
|
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||||
|
llvm::cl::init(MergeSchedulerPeft),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
pimOnlyCodegen("pim-only-codegen",
|
pimOnlyCodegen("pim-only-codegen",
|
||||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||||
@@ -37,19 +34,39 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
|||||||
llvm::cl::init(false),
|
llvm::cl::init(false),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
|
llvm::cl::opt<size_t>
|
||||||
|
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||||
|
|
||||||
llvm::cl::opt<long> coresCount("core-count",
|
llvm::cl::opt<long> coresCount("core-count",
|
||||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||||
llvm::cl::init(-1));
|
llvm::cl::init(-1));
|
||||||
|
|
||||||
|
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||||
|
"dcp-critical-window-size",
|
||||||
|
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||||
|
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||||
|
llvm::cl::init(4000));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
ignoreConcatError("ignore-concat-error",
|
ignoreConcatError("ignore-concat-error",
|
||||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||||
|
|
||||||
|
void verifyExplicitPimCoreCount() {
|
||||||
|
if (!hasExplicitPimCoreCount())
|
||||||
|
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||||
|
if (coresCount.getValue() <= 0)
|
||||||
|
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -20,15 +20,26 @@ typedef enum {
|
|||||||
EmitPimCodegen = 3
|
EmitPimCodegen = 3
|
||||||
} PimEmissionTargetType;
|
} PimEmissionTargetType;
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
MergeSchedulerPeft = 0,
|
||||||
|
MergeSchedulerDcp = 1,
|
||||||
|
} PimMergeSchedulerType;
|
||||||
|
|
||||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||||
|
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
|
extern llvm::cl::opt<bool> pimEmitJson;
|
||||||
|
|
||||||
extern llvm::cl::opt<size_t> crossbarSize;
|
extern llvm::cl::opt<size_t> crossbarSize;
|
||||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||||
extern llvm::cl::opt<long> coresCount;
|
extern llvm::cl::opt<long> coresCount;
|
||||||
|
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||||
|
|
||||||
|
bool hasExplicitPimCoreCount();
|
||||||
|
void verifyExplicitPimCoreCount();
|
||||||
|
|
||||||
// This option, by default set to false, will ignore an error when resolving a
|
// This option, by default set to false, will ignore an error when resolving a
|
||||||
// specific tiles of the operands of a concat. This specific case is when the
|
// specific tiles of the operands of a concat. This specific case is when the
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
PassManager& pm,
|
PassManager& pm,
|
||||||
EmissionTargetType& emissionTarget,
|
EmissionTargetType& emissionTarget,
|
||||||
std::string outputNameNoExt) {
|
std::string outputNameNoExt) {
|
||||||
|
verifyExplicitPimCoreCount();
|
||||||
|
|
||||||
if (pimOnlyCodegen) {
|
if (pimOnlyCodegen) {
|
||||||
// Skip all the lowering passes and directly generate code for PIM.
|
// Skip all the lowering passes and directly generate code for PIM.
|
||||||
@@ -41,6 +42,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||||
pm.addPass(createPimBufferizationPass());
|
pm.addPass(createPimBufferizationPass());
|
||||||
|
pm.addPass(createPimStaticMemoryCoalescingPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Pim bufferized"));
|
pm.addPass(createMessagePass("Pim bufferized"));
|
||||||
}
|
}
|
||||||
@@ -51,9 +53,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||||
pm.addPass(createPimVerificationPass());
|
pm.addPass(createPimVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim verified"));
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimCodePass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
pm.addPass(createMessagePass("Pim code emitted"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,258 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FileSystem.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct DenseWeightView {
|
||||||
|
DenseElementsAttr denseAttr;
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
int64_t offset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||||
|
SmallVector<Operation*> viewOps;
|
||||||
|
mlir::Value current = weight;
|
||||||
|
memref::GetGlobalOp getGlobalOp;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
Operation* defOp = current.getDefiningOp();
|
||||||
|
if (!defOp)
|
||||||
|
return failure();
|
||||||
|
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||||
|
break;
|
||||||
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||||
|
if (!hasAllStaticSubviewParts(subview))
|
||||||
|
return failure();
|
||||||
|
viewOps.push_back(subview);
|
||||||
|
current = subview.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||||
|
current = cast.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||||
|
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||||
|
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||||
|
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
viewOps.push_back(collapse);
|
||||||
|
current = collapse.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||||
|
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||||
|
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||||
|
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
viewOps.push_back(expand);
|
||||||
|
current = expand.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getInitialValue())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
DenseWeightView view;
|
||||||
|
view.denseAttr = denseAttr;
|
||||||
|
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||||
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
|
||||||
|
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||||
|
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||||
|
SmallVector<int64_t> nextStrides;
|
||||||
|
nextStrides.reserve(subview.getStaticStrides().size());
|
||||||
|
for (auto [offset, stride, sourceStride] :
|
||||||
|
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||||
|
view.offset += offset * sourceStride;
|
||||||
|
nextStrides.push_back(stride * sourceStride);
|
||||||
|
}
|
||||||
|
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||||
|
view.strides = std::move(nextStrides);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||||
|
// dense global view, so a row-major stride recomputation preserves layout.
|
||||||
|
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||||
|
if (view.strides != computeRowMajorStrides(view.shape))
|
||||||
|
return failure();
|
||||||
|
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||||
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||||
|
if (view.strides != computeRowMajorStrides(view.shape))
|
||||||
|
return failure();
|
||||||
|
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||||
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||||
|
SmallVector<unsigned, 8> indices;
|
||||||
|
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||||
|
auto addWeight = [&](mlir::Value weight) {
|
||||||
|
if (!coreOp)
|
||||||
|
return;
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||||
|
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||||
|
continue;
|
||||||
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
|
indices.push_back(weightIndex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||||
|
llvm::sort(indices);
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||||
|
return getUsedWeightIndices(coreOp.getBody().front());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||||
|
SmallVector<Operation*> coreLikeOps;
|
||||||
|
for (Operation& op : funcOp.getBody().front())
|
||||||
|
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||||
|
coreLikeOps.push_back(&op);
|
||||||
|
return coreLikeOps;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||||
|
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||||
|
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||||
|
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||||
|
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||||
|
assert(!error && "Error creating weights directory");
|
||||||
|
size_t indexFileName = 0;
|
||||||
|
|
||||||
|
int64_t xbarSize = crossbarSize.getValue();
|
||||||
|
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||||
|
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||||
|
|
||||||
|
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||||
|
|
||||||
|
for (Operation* op : coreLikeOps) {
|
||||||
|
auto processCore = [&](pim::PimCoreOp coreOp) {
|
||||||
|
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
|
||||||
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
|
|
||||||
|
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||||
|
if (failed(weightView)) {
|
||||||
|
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||||
|
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||||
|
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||||
|
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||||
|
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||||
|
ArrayRef<int64_t> shape = weightView->shape;
|
||||||
|
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||||
|
int64_t numRows = shape[0];
|
||||||
|
int64_t numCols = shape[1];
|
||||||
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||||
|
|
||||||
|
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
|
|
||||||
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||||
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||||
|
assert(errorCode);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t zero = 0;
|
||||||
|
for (int64_t row = 0; row < xbarSize; row++) {
|
||||||
|
for (int64_t col = 0; col < xbarSize; col++) {
|
||||||
|
if (row < numRows && col < numCols) {
|
||||||
|
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||||
|
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||||
|
uint64_t word = bits.getZExtValue();
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
weightFileStream.close();
|
||||||
|
if (globalOp)
|
||||||
|
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||||
|
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
|
(void) processCore(coreOp);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||||
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||||
|
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
|
||||||
|
return mapCoreWeightToFileName;
|
||||||
|
}
|
||||||
|
return mapCoreWeightToFileName;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||||
|
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|||||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||||
|
|
||||||
add_pim_library(OMONNXToSpatial
|
add_pim_library(OMONNXToSpatial
|
||||||
|
ConversionPatterns.cpp
|
||||||
|
HostFoldability.cpp
|
||||||
|
HostLegality.cpp
|
||||||
|
PrePatterns.cpp
|
||||||
|
PostPatterns.cpp
|
||||||
Patterns/Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
Patterns/Math/Elementwise.cpp
|
Patterns/Math/Elementwise.cpp
|
||||||
Patterns/Math/Gemm.cpp
|
Patterns/Math/Gemm.cpp
|
||||||
@@ -18,7 +23,9 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Reshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
Common.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
|
Common/ShapeTilingUtils.cpp
|
||||||
|
Common/WeightMaterialization.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -1,279 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getImageN(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class ShapedType>
|
|
||||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
|
||||||
return shapedType.getDimSize(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
using HSliceId = size_t;
|
|
||||||
using CoreId = size_t;
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr C ceilIntegerDivide(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return 1 + (ac - 1) / bc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[0] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[1] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
|
||||||
assert(isVectorShape(shape));
|
|
||||||
return shape[0] != 1 ? shape[0] : shape[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
inline auto getTensorShape(mlir::Value tensor) {
|
|
||||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool isWeightLikeComputeOperand(mlir::Value value) {
|
|
||||||
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
|
|
||||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
|
|
||||||
|
|
||||||
while (auto* definingOp = value.getDefiningOp()) {
|
|
||||||
if (!visited.insert(definingOp).second)
|
|
||||||
return false;
|
|
||||||
if (hasWeightAlways(definingOp))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
|
|
||||||
value = extractSliceOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = expandShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = collapseShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
|
|
||||||
value = transposeOp.getData();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace detail {
|
|
||||||
|
|
||||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
|
||||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
|
||||||
return std::forward<Fn>(fn)(values[Is]...);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t>
|
|
||||||
using ValueArg = mlir::Value;
|
|
||||||
|
|
||||||
template <typename Fn, typename Seq>
|
|
||||||
struct InvokeWithBlockArgsResult;
|
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
|
||||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
|
||||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Fn, typename Seq>
|
|
||||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
|
||||||
|
|
||||||
template <typename Fn>
|
|
||||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
|
|
||||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::TypeRange resultTypes,
|
|
||||||
mlir::ValueRange weights,
|
|
||||||
mlir::ValueRange inputs,
|
|
||||||
BodyFn&& body) {
|
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
|
||||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
|
||||||
for (mlir::Value input : inputs)
|
|
||||||
block->addArgument(input.getType(), loc);
|
|
||||||
|
|
||||||
computeOp.getBody().push_back(block);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
|
||||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
|
||||||
auto bodyResult =
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
|
||||||
if (mlir::failed(bodyResult)) {
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
rewriter.eraseOp(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return computeOp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename RewriterT, typename BodyFn>
|
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::TypeRange resultTypes,
|
|
||||||
mlir::ValueRange weights,
|
|
||||||
mlir::ValueRange inputs,
|
|
||||||
BodyFn&& body) {
|
|
||||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
|
||||||
for (mlir::Value input : inputs)
|
|
||||||
block->addArgument(input.getType(), loc);
|
|
||||||
|
|
||||||
computeOp.getBody().push_back(block);
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
|
||||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
|
||||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
|
||||||
if (mlir::failed(bodyResult)) {
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
rewriter.eraseOp(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
|
||||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
return computeOp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
|
||||||
size_t axis,
|
|
||||||
int64_t sliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
|
||||||
int64_t sliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
|
||||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
|
||||||
tileMatrix(mlir::Value& matrixToTile,
|
|
||||||
int64_t hSliceSize,
|
|
||||||
int64_t vSliceSize,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location& loc);
|
|
||||||
|
|
||||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
|
||||||
int64_t length,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "WeightMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "ComputeRegionBuilder.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||||
|
if (tensors.size() == 1)
|
||||||
|
return tensors[0];
|
||||||
|
|
||||||
|
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||||
|
SmallVector<Value> tensors2;
|
||||||
|
tensors2.reserve(tensors.size() / 2);
|
||||||
|
|
||||||
|
auto* currTensors = &tensors1;
|
||||||
|
auto* nextTensors = &tensors2;
|
||||||
|
while (currTensors->size() > 1) {
|
||||||
|
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||||
|
Value a = (*currTensors)[i];
|
||||||
|
Value b = (*currTensors)[i + 1];
|
||||||
|
rewriter.setInsertionPointAfterValue(b);
|
||||||
|
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||||
|
nextTensors->push_back(addedValue);
|
||||||
|
}
|
||||||
|
if (currTensors->size() % 2 == 1)
|
||||||
|
nextTensors->push_back(currTensors->back());
|
||||||
|
std::swap(currTensors, nextTensors);
|
||||||
|
nextTensors->clear();
|
||||||
|
}
|
||||||
|
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||||
|
return (*currTensors)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||||
|
|
||||||
|
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||||
|
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
|
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||||
|
return std::forward<Fn>(fn)(values[Is]...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t>
|
||||||
|
using ValueArg = mlir::Value;
|
||||||
|
|
||||||
|
template <typename Fn, typename Seq>
|
||||||
|
struct InvokeWithBlockArgsResult;
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||||
|
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Fn, typename Seq>
|
||||||
|
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <typename RewriterT>
|
||||||
|
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
||||||
|
assert(!inputs.empty() && "spat.concat requires at least one input");
|
||||||
|
if (inputs.size() == 1)
|
||||||
|
return inputs.front();
|
||||||
|
|
||||||
|
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
||||||
|
auto outputShape = llvm::to_vector(firstType.getShape());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
bool concatDimDynamic = false;
|
||||||
|
|
||||||
|
for (mlir::Value input : inputs) {
|
||||||
|
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
||||||
|
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
||||||
|
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
||||||
|
concatDimDynamic = true;
|
||||||
|
else
|
||||||
|
concatDimSize += inputType.getDimSize(axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
||||||
|
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
||||||
|
/// the body callback reports failure.
|
||||||
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value weight : weights)
|
||||||
|
block->addArgument(weight.getType(), loc);
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
detail::invokeWithValues(
|
||||||
|
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult = detail::invokeWithValues(
|
||||||
|
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
||||||
|
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value weight : weights)
|
||||||
|
block->addArgument(weight.getType(), loc);
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
|
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+39
-50
@@ -1,24 +1,12 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "Common.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -44,10 +32,29 @@ SmallVector<Value> sliceTensor(
|
|||||||
|
|
||||||
for (int64_t i = 0; i < numSlices; i++) {
|
for (int64_t i = 0; i < numSlices; i++) {
|
||||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
int64_t currentSliceSize = sliceSize;
|
||||||
|
if (i == numSlices - 1 && lastSliceSize != 0) {
|
||||||
|
currentSliceSize = lastSliceSize;
|
||||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||||
|
}
|
||||||
|
|
||||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
|
||||||
|
sliceShape[axis] = currentSliceSize;
|
||||||
|
auto sliceType =
|
||||||
|
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||||
|
|
||||||
|
Value slice;
|
||||||
|
if (isHostFoldableValue(tensorToSlice)) {
|
||||||
|
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto sliceCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
|
||||||
|
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
|
||||||
|
});
|
||||||
|
slice = sliceCompute.getResult(0);
|
||||||
|
}
|
||||||
slices.push_back(slice);
|
slices.push_back(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,45 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
|||||||
return tiles;
|
return tiles;
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor::SplatOp
|
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||||
Type elementType = oldType.getElementType();
|
Type elementType = oldType.getElementType();
|
||||||
int64_t shape[2] = {1, length};
|
int64_t shape[2] = {1, length};
|
||||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||||
|
|
||||||
|
auto buildBroadcast = [&](Value input) -> Value {
|
||||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
SmallVector<Value> index(oldType.getRank(), zero);
|
SmallVector<Value> index(oldType.getRank(), zero);
|
||||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||||
|
|
||||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(scalarToBroadcast))
|
||||||
|
return buildBroadcast(scalarToBroadcast);
|
||||||
|
|
||||||
|
auto broadcastCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||||
|
});
|
||||||
|
return broadcastCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
} // namespace onnx_mlir
|
||||||
if (tensors.size() == 1)
|
|
||||||
return tensors[0];
|
|
||||||
|
|
||||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
|
||||||
SmallVector<Value> tensors2;
|
|
||||||
tensors2.reserve(tensors.size() / 2);
|
|
||||||
|
|
||||||
auto* currTensors = &tensors1;
|
|
||||||
auto* nextTensors = &tensors2;
|
|
||||||
while (currTensors->size() > 1) {
|
|
||||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
|
||||||
Value a = (*currTensors)[i];
|
|
||||||
Value b = (*currTensors)[i + 1];
|
|
||||||
rewriter.setInsertionPointAfterValue(b);
|
|
||||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
|
||||||
nextTensors->push_back(addedValue);
|
|
||||||
}
|
|
||||||
if (currTensors->size() % 2 == 1)
|
|
||||||
nextTensors->push_back(currTensors->back());
|
|
||||||
std::swap(currTensors, nextTensors);
|
|
||||||
nextTensors->clear();
|
|
||||||
}
|
|
||||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
|
||||||
return (*currTensors)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageN(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
using HSliceId = size_t;
|
||||||
|
using CoreId = size_t;
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr C ceilIntegerDivide(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return 1 + (ac - 1) / bc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[0] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[1] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||||
|
assert(isVectorShape(shape));
|
||||||
|
return shape[0] != 1 ? shape[0] : shape[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
inline auto getTensorShape(mlir::Value tensor) {
|
||||||
|
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||||
|
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||||
|
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||||
|
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||||
|
&& lhsType.getShape() == rhsType.getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||||
|
/// at most `sliceSize` elements.
|
||||||
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
|
size_t axis,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||||
|
/// current PIM target geometry.
|
||||||
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
|
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
|
/// Tiles a matrix first across output columns and then across input rows so it
|
||||||
|
/// can be assigned to crossbars grouped by core.
|
||||||
|
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||||
|
tileMatrix(mlir::Value& matrixToTile,
|
||||||
|
int64_t hSliceSize,
|
||||||
|
int64_t vSliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location& loc);
|
||||||
|
|
||||||
|
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||||
|
int64_t length,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
|
#include "WeightMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isWeightLikeComputeOperand(Value value) {
|
||||||
|
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
|
||||||
|
while (auto* definingOp = value.getDefiningOp()) {
|
||||||
|
if (!visited.insert(definingOp).second)
|
||||||
|
return false;
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
value = extractSliceOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||||
|
value = transposeOp.getData();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(value))
|
||||||
|
return cast<Value>(mapped);
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||||
|
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!tensorType || !tensorType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
sizes.reserve(tensorType.getRank());
|
||||||
|
for (int64_t dim : tensorType.getShape())
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
|
|
||||||
|
auto referencedValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||||
|
mapper.map(value, referencedValue.getResult());
|
||||||
|
return referencedValue.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
IRMapping localMapper;
|
||||||
|
for (Value operand : definingOp->getOperands()) {
|
||||||
|
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||||
|
localMapper.map(operand, cast<Value>(mapped));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isWeightLikeComputeOperand(operand)) {
|
||||||
|
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||||
|
if (failed(clonedOperand))
|
||||||
|
return failure();
|
||||||
|
localMapper.map(operand, *clonedOperand);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
localMapper.map(operand, operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||||
|
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapper.map(oldResult, newResult);
|
||||||
|
|
||||||
|
auto mapped = mapper.lookupOrNull(value);
|
||||||
|
if (!mapped)
|
||||||
|
return failure();
|
||||||
|
return cast<Value>(mapped);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Returns true when a matrix-valued compute operand is ultimately backed by a
|
||||||
|
/// weight-marked constant/view chain and can be promoted into weights.
|
||||||
|
bool isWeightLikeComputeOperand(mlir::Value value);
|
||||||
|
|
||||||
|
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
|
||||||
|
/// compute body while reusing already-materialized intermediate values.
|
||||||
|
llvm::FailureOr<mlir::Value>
|
||||||
|
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||||
|
patterns.add<removeLRN>(ctx);
|
||||||
|
|
||||||
|
populateElementwisePatterns(patterns, ctx);
|
||||||
|
populateGemmPatterns(patterns, ctx);
|
||||||
|
populateConvPatterns(patterns, ctx);
|
||||||
|
populatePoolPatterns(patterns, ctx);
|
||||||
|
populateReduceMeanPatterns(patterns, ctx);
|
||||||
|
populateReluPatterns(patterns, ctx);
|
||||||
|
populateSigmoidPatterns(patterns, ctx);
|
||||||
|
populateSoftmaxPatterns(patterns, ctx);
|
||||||
|
populateConcatPatterns(patterns, ctx);
|
||||||
|
populateGatherPatterns(patterns, ctx);
|
||||||
|
populateResizePatterns(patterns, ctx);
|
||||||
|
populateReshapePatterns(patterns, ctx);
|
||||||
|
populateSplitPatterns(patterns, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+2
@@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
@@ -0,0 +1,256 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||||
|
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||||
|
return llvm::all_of(extractOp.getIndices(),
|
||||||
|
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isStaticTensorResult(Operation* op) {
|
||||||
|
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||||
|
auto shapedType = dyn_cast<ShapedType>(type);
|
||||||
|
return shapedType && shapedType.hasStaticShape();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||||
|
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
if (!tensorType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t rank = tensorType.getRank();
|
||||||
|
if (static_cast<int64_t>(perms.size()) != rank)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
llvm::SmallBitVector seen(rank);
|
||||||
|
SmallVector<int64_t> transposedShape;
|
||||||
|
transposedShape.reserve(rank);
|
||||||
|
for (int64_t perm : perms) {
|
||||||
|
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||||
|
return failure();
|
||||||
|
seen.set(perm);
|
||||||
|
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||||
|
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||||
|
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||||
|
SmallVector<int64_t> originalIndices(rank);
|
||||||
|
|
||||||
|
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||||
|
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
originalIndices[dim] = remaining / originalStrides[dim];
|
||||||
|
remaining %= originalStrides[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t transposedLinearIndex = 0;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||||
|
|
||||||
|
transposedValues[transposedLinearIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||||
|
return DenseElementsAttr::get(resultType, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||||
|
tensor::ExtractSliceOp extractSliceOp) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||||
|
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||||
|
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||||
|
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||||
|
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||||
|
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||||
|
SmallVector<Attribute> resultValues;
|
||||||
|
resultValues.reserve(resultType.getNumElements());
|
||||||
|
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||||
|
int64_t remaining = linearIndex;
|
||||||
|
int64_t sourceLinearIndex = 0;
|
||||||
|
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||||
|
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||||
|
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||||
|
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||||
|
}
|
||||||
|
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(resultType, resultValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||||
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||||
|
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||||
|
auto* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp || !visited.insert(definingOp).second)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||||
|
// lowering stages can still recognize grouped/sliced constants.
|
||||||
|
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||||
|
return denseAttr;
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
perm.reserve(transposeOp.getPermAttr().size());
|
||||||
|
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||||
|
perm.push_back(attr.getInt());
|
||||||
|
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||||
|
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||||
|
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||||
|
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||||
|
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||||
|
if (!op || !visited.insert(op).second)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||||
|
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||||
|
|
||||||
|
if (!isStaticTensorResult(op))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||||
|
return isHostFoldableValue(transposeOp.getData());
|
||||||
|
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||||
|
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||||
|
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||||
|
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
|
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||||
|
|
||||||
|
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||||
|
return isHostFoldableValue(splatOp.getInput());
|
||||||
|
|
||||||
|
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||||
|
return isHostFoldableValue(extractRowsOp.getInput());
|
||||||
|
|
||||||
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||||
|
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isHostFoldableValue(Value value) {
|
||||||
|
auto* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
return isHostFoldableOpImpl(definingOp, visited);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isHostFoldableOp(Operation* op) {
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
return isHostFoldableOpImpl(op, visited);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool isHostFoldableValue(mlir::Value value);
|
||||||
|
|
||||||
|
bool isHostFoldableOp(mlir::Operation* op);
|
||||||
|
|
||||||
|
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||||
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
|
||||||
|
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||||
|
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||||
|
continue;
|
||||||
|
if (isHostFoldableOp(&op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||||
|
"spat.compute");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||||
|
|
||||||
|
return success(!diagnostics.hasFailure());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,31 +1,23 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <fstream>
|
#include "Common/Common.hpp"
|
||||||
#include <iterator>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "Common.hpp"
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -33,12 +25,8 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
bool haveSameStaticShape(Value lhs, Value rhs);
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
|
||||||
|
|
||||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||||
@@ -48,35 +36,118 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||||
|
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
|
|
||||||
private:
|
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
|
||||||
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
|
||||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
IRMapping mapper;
|
||||||
|
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||||
|
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||||
|
if (!computes.empty() || !computeBatches.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
|
||||||
|
SmallVector<Type> sourceTypes;
|
||||||
|
SmallVector<Location> sourceLocs;
|
||||||
|
sourceTypes.reserve(funcOp.getNumArguments());
|
||||||
|
sourceLocs.reserve(funcOp.getNumArguments());
|
||||||
|
for (Value source : funcOp.getArguments()) {
|
||||||
|
sourceTypes.push_back(source.getType());
|
||||||
|
sourceLocs.push_back(source.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newCompute = spatial::SpatCompute::create(
|
||||||
|
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||||
|
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||||
|
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||||
|
mapper.map(computeArg, blockArg);
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
for (Operation& op : funcOp.getOps())
|
||||||
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||||
|
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||||
|
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||||
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
||||||
|
op.dropAllUses();
|
||||||
|
rewriter.eraseOp(&op);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||||
|
returnOp.setOperand(index, computeResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||||
|
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||||
|
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Transpose stays globally legal because constant/view-only cases are
|
||||||
|
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||||
|
// spat.compute before the host legality check.
|
||||||
|
auto resultType = transposeOp.getResult().getType();
|
||||||
|
rewriter.setInsertionPoint(transposeOp);
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||||
|
Value transposed =
|
||||||
|
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||||
|
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
RewritePatternSet mergeActivationPatterns(ctx);
|
ConversionTarget preTarget(*ctx);
|
||||||
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
ONNXDialect,
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
tensor::TensorDialect,
|
||||||
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
arith::ArithDialect,
|
||||||
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
scf::SCFDialect>();
|
||||||
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
RewritePatternSet prePatterns(ctx);
|
||||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
populatePrePatterns(prePatterns, ctx);
|
||||||
|
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
|
||||||
|
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
IRRewriter rewriter(moduleOp);
|
|
||||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
if (failed(entryFunc)) {
|
if (failed(entryFunc)) {
|
||||||
|
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
RewritePatternSet matmulPatterns(ctx);
|
||||||
|
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||||
|
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||||
|
|
||||||
|
bool hasUnloweredMatMul = false;
|
||||||
|
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||||
|
hasUnloweredMatMul = true;
|
||||||
|
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||||
|
});
|
||||||
|
if (hasUnloweredMatMul) {
|
||||||
|
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -87,8 +158,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
target.addIllegalOp<ONNXMatMulOp>();
|
||||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
|
||||||
target.addIllegalOp<ONNXAddOp>();
|
target.addIllegalOp<ONNXAddOp>();
|
||||||
target.addIllegalOp<ONNXDivOp>();
|
target.addIllegalOp<ONNXDivOp>();
|
||||||
target.addIllegalOp<ONNXMulOp>();
|
target.addIllegalOp<ONNXMulOp>();
|
||||||
@@ -107,370 +177,60 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||||
target.addIllegalOp<ONNXSplitOp>();
|
target.addIllegalOp<ONNXSplitOp>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet conversionPatterns(ctx);
|
||||||
patterns.add<removeLRN>(ctx);
|
populateConversionPatterns(conversionPatterns, ctx);
|
||||||
|
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||||
populateElementwisePatterns(patterns, ctx);
|
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
|
||||||
populateGemmPatterns(patterns, ctx);
|
|
||||||
populateConvPatterns(patterns, ctx);
|
|
||||||
populatePoolPatterns(patterns, ctx);
|
|
||||||
populateReduceMeanPatterns(patterns, ctx);
|
|
||||||
populateReluPatterns(patterns, ctx);
|
|
||||||
populateSigmoidPatterns(patterns, ctx);
|
|
||||||
populateSoftmaxPatterns(patterns, ctx);
|
|
||||||
populateConcatPatterns(patterns, ctx);
|
|
||||||
populateGatherPatterns(patterns, ctx);
|
|
||||||
populateResizePatterns(patterns, ctx);
|
|
||||||
populateReshapePatterns(patterns, ctx);
|
|
||||||
populateSplitPatterns(patterns, ctx);
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Count the number of compute ops and check they do not exceed the core count
|
ConversionTarget earlyPostTarget(*ctx);
|
||||||
if (coresCount != -1) {
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
int computeOpsCount = 0;
|
ONNXDialect,
|
||||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
tensor::TensorDialect,
|
||||||
if (isa<spatial::SpatWeightedCompute>(op))
|
arith::ArithDialect,
|
||||||
computeOpsCount++;
|
scf::SCFDialect>();
|
||||||
|
|
||||||
if (computeOpsCount > coresCount) {
|
|
||||||
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PassManager cleanupPM(ctx);
|
PassManager cleanupPM(ctx);
|
||||||
cleanupPM.addPass(createCanonicalizerPass());
|
cleanupPM.addPass(createCanonicalizerPass());
|
||||||
if (failed(cleanupPM.run(moduleOp)))
|
if (failed(cleanupPM.run(moduleOp)))
|
||||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
|
||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
|
||||||
|
|
||||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
ConversionTarget postTarget(*ctx);
|
||||||
|
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
ONNXDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||||
|
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||||
|
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
|
||||||
|
RewritePatternSet postPatterns(ctx);
|
||||||
|
populatePostPatterns(postPatterns, ctx);
|
||||||
|
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
mergeTriviallyConnectedComputes(*entryFunc);
|
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||||
|
|
||||||
|
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||||
|
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
// Dump to file for debug
|
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
|
||||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
|
||||||
Value source = funcSource(toRemoveOp);
|
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
|
||||||
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
|
|
||||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
|
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
|
||||||
IRMapping mapper;
|
|
||||||
mapper.map(source, BB->getArgument(0));
|
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
|
|
||||||
inst->replaceAllUsesWith(newCompute);
|
|
||||||
inst->erase();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|
||||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
|
||||||
auto sources = toRemoveOp.getInputs();
|
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
|
||||||
if (llvm::any_of(
|
|
||||||
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
|
|
||||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
|
|
||||||
SmallVector<Type> sourceTypes;
|
|
||||||
SmallVector<Location> sourceLoc;
|
|
||||||
for (auto source : sources) {
|
|
||||||
sourceTypes.push_back(source.getType());
|
|
||||||
sourceLoc.push_back(loc);
|
|
||||||
}
|
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
|
||||||
IRMapping mapper;
|
|
||||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
|
||||||
mapper.map(source, bbArg);
|
|
||||||
auto newConcat = rewriter.clone(*inst, mapper);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
|
|
||||||
inst->replaceAllUsesWith(newCompute);
|
|
||||||
inst->erase();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
|
||||||
if (auto mapped = mapper.lookupOrNull(value))
|
|
||||||
return cast<Value>(mapped);
|
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
|
||||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
|
||||||
if (!tensorType || !tensorType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
|
||||||
sizes.reserve(tensorType.getRank());
|
|
||||||
for (int64_t dim : tensorType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
|
||||||
|
|
||||||
auto referencedValue =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
|
||||||
mapper.map(value, referencedValue.getResult());
|
|
||||||
return referencedValue.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
IRMapping localMapper;
|
|
||||||
for (Value operand : definingOp->getOperands()) {
|
|
||||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
|
||||||
localMapper.map(operand, cast<Value>(mapped));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isWeightLikeComputeOperand(operand)) {
|
|
||||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
|
||||||
if (failed(clonedOperand))
|
|
||||||
return failure();
|
|
||||||
localMapper.map(operand, *clonedOperand);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
localMapper.map(operand, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
|
||||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
|
||||||
mapper.map(oldResult, newResult);
|
|
||||||
|
|
||||||
auto mapped = mapper.lookupOrNull(value);
|
|
||||||
if (!mapped)
|
|
||||||
return failure();
|
|
||||||
return cast<Value>(mapped);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO what we want to keep in global?
|
|
||||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|
||||||
Location loc = funcOp.getLoc();
|
|
||||||
IRRewriter rewriter(&getContext());
|
|
||||||
bool keep = true;
|
|
||||||
while (keep) {
|
|
||||||
keep = false;
|
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
|
||||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<ONNXTransposeOp>(
|
|
||||||
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<tensor::CollapseShapeOp>(
|
|
||||||
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
|
||||||
|
|
||||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|
||||||
Location loc = funcOp.getLoc();
|
|
||||||
IRRewriter rewriter(&getContext());
|
|
||||||
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
|
|
||||||
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
|
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
|
||||||
if (compute->hasOneUse()) {
|
|
||||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
|
||||||
|
|
||||||
if (user && user.getInputs().size() == 1)
|
|
||||||
trivialComputes.push_back(compute);
|
|
||||||
}
|
|
||||||
|
|
||||||
while (!trivialComputes.empty()) {
|
|
||||||
auto compute = trivialComputes.front();
|
|
||||||
|
|
||||||
if (compute.use_empty()) {
|
|
||||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
|
||||||
trivialComputes.pop_back();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
|
||||||
|
|
||||||
auto newCompute =
|
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
auto weightMutableIter = newCompute.getWeightsMutable();
|
|
||||||
for (auto weight : child.getWeights()) {
|
|
||||||
auto founded = llvm::find(newCompute.getWeights(), weight);
|
|
||||||
if (founded == newCompute.getWeights().end()) {
|
|
||||||
weightMutableIter.append(weight);
|
|
||||||
auto last = weightMutableIter.end();
|
|
||||||
last = std::prev(last, 1);
|
|
||||||
mapper.map(weight, last->get());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
mapper.map(weight, *founded);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
|
||||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
|
||||||
newTerminator->erase();
|
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
|
||||||
for (auto& op : child.getBody().front()) {
|
|
||||||
auto newInst = rewriter.clone(op, mapper);
|
|
||||||
|
|
||||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
|
||||||
auto oldIndex = vmOp.getWeightIndex();
|
|
||||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
|
||||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
|
||||||
vmOp.setWeightIndex(newIndex);
|
|
||||||
}
|
|
||||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
|
||||||
auto oldIndex = vmOp.getWeightIndex();
|
|
||||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
|
||||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
|
||||||
vmOp.setWeightIndex(newIndex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
child.replaceAllUsesWith(newCompute);
|
|
||||||
toErase.insert(child);
|
|
||||||
|
|
||||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
|
||||||
trivialComputes.pop_back();
|
|
||||||
toErase.insert(compute);
|
|
||||||
|
|
||||||
if (newCompute->hasOneUse()) {
|
|
||||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
|
|
||||||
if (user && user.getInputs().size() == 1)
|
|
||||||
trivialComputes.push_back(newCompute);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto compute : toErase) {
|
|
||||||
compute.getResult(0).dropAllUses();
|
|
||||||
compute.erase();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
|
||||||
bool isAlwaysWeight =
|
|
||||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
|
|
||||||
if (isAlwaysWeight)
|
|
||||||
markWeightAlways(constantOp);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
|
||||||
IRRewriter rewriter(&getContext());
|
|
||||||
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
|
|
||||||
|
|
||||||
for (auto compute : computes) {
|
|
||||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
|
||||||
bool needsRewrite = false;
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
|
||||||
continue;
|
|
||||||
promoteInput[inputIdx] = true;
|
|
||||||
needsRewrite = true;
|
|
||||||
}
|
|
||||||
if (!needsRewrite)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
|
||||||
|
|
||||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
|
||||||
SmallVector<Value> newInputs;
|
|
||||||
SmallVector<Type> newInputTypes;
|
|
||||||
SmallVector<Location> newInputLocs;
|
|
||||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
|
||||||
newInputs.reserve(compute.getInputs().size());
|
|
||||||
newInputTypes.reserve(compute.getInputs().size());
|
|
||||||
newInputLocs.reserve(compute.getInputs().size());
|
|
||||||
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
|
||||||
if (promoteInput[inputIdx]) {
|
|
||||||
newWeights.push_back(input);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
newInputs.push_back(input);
|
|
||||||
newInputTypes.push_back(input.getType());
|
|
||||||
newInputLocs.push_back(input.getLoc());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newCompute =
|
|
||||||
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
|
||||||
auto* newBlock =
|
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
auto& oldBlock = compute.getBody().front();
|
|
||||||
size_t newInputIdx = 0;
|
|
||||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
|
||||||
if (!promoteInput[oldInputIdx]) {
|
|
||||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
|
|
||||||
if (failed(clonedValue))
|
|
||||||
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
|
|
||||||
mapper.map(oldArg, *clonedValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& op : oldBlock.without_terminator())
|
|
||||||
rewriter.clone(op, mapper);
|
|
||||||
|
|
||||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
|
||||||
SmallVector<Value> newYieldOperands;
|
|
||||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
|
||||||
for (Value operand : oldYield.getOutputs()) {
|
|
||||||
auto mapped = mapper.lookupOrNull(operand);
|
|
||||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
|
||||||
}
|
|
||||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
|
||||||
|
|
||||||
compute.replaceAllUsesWith(newCompute);
|
|
||||||
compute.erase();
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -7,11 +7,11 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -28,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
|
||||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
|
||||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
|
||||||
|
|
||||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
|
||||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||||
|
|
||||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -147,10 +137,11 @@ static Value buildPackedBias(bool hasBias,
|
|||||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createIm2colCompute(Value x,
|
static Value createIm2colRowComputes(Value x,
|
||||||
RankedTensorType xType,
|
RankedTensorType xType,
|
||||||
RankedTensorType im2colType,
|
RankedTensorType im2colType,
|
||||||
RankedTensorType rowType,
|
RankedTensorType im2colRowType,
|
||||||
|
RankedTensorType gemmInputRowsType,
|
||||||
int64_t batchSize,
|
int64_t batchSize,
|
||||||
int64_t numChannelsIn,
|
int64_t numChannelsIn,
|
||||||
int64_t xHeight,
|
int64_t xHeight,
|
||||||
@@ -169,11 +160,14 @@ static Value createIm2colCompute(Value x,
|
|||||||
int64_t patchSize,
|
int64_t patchSize,
|
||||||
int64_t numPatches,
|
int64_t numPatches,
|
||||||
int64_t numPatchesPerBatch,
|
int64_t numPatchesPerBatch,
|
||||||
|
int64_t packFactor,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
|
auto im2colComputeOp =
|
||||||
|
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
|
|
||||||
// Pad input with zeros if needed:
|
// Pad input with zeros if needed:
|
||||||
@@ -240,7 +234,7 @@ static Value createIm2colCompute(Value x,
|
|||||||
|
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rowType,
|
im2colRowType,
|
||||||
patch,
|
patch,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0},
|
||||||
@@ -256,28 +250,13 @@ static Value createIm2colCompute(Value x,
|
|||||||
|
|
||||||
rewriter.setInsertionPointAfter(im2colLoop);
|
rewriter.setInsertionPointAfter(im2colLoop);
|
||||||
Value im2col = im2colLoop.getResult(0);
|
Value im2col = im2colLoop.getResult(0);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
|
||||||
});
|
|
||||||
return im2colComputeOp.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value createPackedIm2colRows(Value im2col,
|
Value gemmInputRows = im2col;
|
||||||
RankedTensorType im2colType,
|
if (packFactor != 1) {
|
||||||
Type elemType,
|
|
||||||
int64_t numPatches,
|
|
||||||
int64_t patchSize,
|
|
||||||
int64_t packFactor,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
if (packFactor == 1)
|
|
||||||
return im2col;
|
|
||||||
|
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||||
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
|
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||||
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
|
|
||||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
groupedType,
|
groupedType,
|
||||||
@@ -286,7 +265,7 @@ static Value createPackedIm2colRows(Value im2col,
|
|||||||
{0, 1},
|
{0, 1},
|
||||||
{2}
|
{2}
|
||||||
});
|
});
|
||||||
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
|
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
packedType,
|
packedType,
|
||||||
groupedIm2col,
|
groupedIm2col,
|
||||||
@@ -294,31 +273,39 @@ static Value createPackedIm2colRows(Value im2col,
|
|||||||
{0},
|
{0},
|
||||||
{1, 2}
|
{1, 2}
|
||||||
});
|
});
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
|
|
||||||
});
|
|
||||||
return packedComputeOp.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createUnpackedOutput(Value packedOutput,
|
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||||
|
});
|
||||||
|
|
||||||
|
return im2colComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||||
|
Type convType,
|
||||||
RankedTensorType gemmOutType,
|
RankedTensorType gemmOutType,
|
||||||
|
RankedTensorType nhwcType,
|
||||||
RankedTensorType outType,
|
RankedTensorType outType,
|
||||||
int64_t numPatches,
|
int64_t numPatches,
|
||||||
int64_t numChannelsOut,
|
int64_t numChannelsOut,
|
||||||
int64_t packFactor,
|
int64_t packFactor,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (packFactor == 1)
|
|
||||||
return packedOutput;
|
|
||||||
|
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
|
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||||
|
Value gemmOut;
|
||||||
|
if (packFactor == 1) {
|
||||||
|
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
|
}
|
||||||
|
else {
|
||||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
|
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
expandedType,
|
expandedType,
|
||||||
packedOutputArg,
|
packedOutput,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0},
|
||||||
{1, 2}
|
{1, 2}
|
||||||
@@ -332,30 +319,15 @@ static Value createUnpackedOutput(Value packedOutput,
|
|||||||
{2}
|
{2}
|
||||||
});
|
});
|
||||||
|
|
||||||
Value unpackedOutput = paddedOutput;
|
gemmOut = paddedOutput;
|
||||||
if (paddedNumPatches != numPatches) {
|
if (paddedNumPatches != numPatches) {
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
unpackedOutput =
|
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
|
|
||||||
});
|
|
||||||
return unpackComputeOp.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createCollectedConvOutput(Value gemmOut,
|
|
||||||
Type convType,
|
|
||||||
RankedTensorType nhwcType,
|
|
||||||
RankedTensorType outType,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto collectComputeOp =
|
|
||||||
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
|
|
||||||
Value gemmOutArg = gemmOutArgs.front();
|
|
||||||
|
|
||||||
// Restore to NCHW layout:
|
// Restore to NCHW layout:
|
||||||
// [numPatches, numChannelsOut]
|
// [numPatches, numChannelsOut]
|
||||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||||
@@ -363,7 +335,7 @@ static Value createCollectedConvOutput(Value gemmOut,
|
|||||||
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
nhwcType,
|
nhwcType,
|
||||||
gemmOutArg,
|
gemmOut,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0, 1, 2},
|
{0, 1, 2},
|
||||||
{3}
|
{3}
|
||||||
@@ -374,6 +346,160 @@ static Value createCollectedConvOutput(Value gemmOut,
|
|||||||
return collectComputeOp.getResult(0);
|
return collectComputeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value lowerSingleConvGroup(Value x,
|
||||||
|
Value w,
|
||||||
|
Value b,
|
||||||
|
RankedTensorType xType,
|
||||||
|
RankedTensorType wType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t padHeightBegin,
|
||||||
|
int64_t padHeightEnd,
|
||||||
|
int64_t padWidthBegin,
|
||||||
|
int64_t padWidthEnd,
|
||||||
|
int64_t strideHeight,
|
||||||
|
int64_t strideWidth,
|
||||||
|
int64_t dilationHeight,
|
||||||
|
int64_t dilationWidth,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
|
const int64_t xWidth = xType.getDimSize(3);
|
||||||
|
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||||
|
const int64_t wHeight = wType.getDimSize(2);
|
||||||
|
const int64_t wWidth = wType.getDimSize(3);
|
||||||
|
const int64_t outHeight = outType.getDimSize(2);
|
||||||
|
const int64_t outWidth = outType.getDimSize(3);
|
||||||
|
|
||||||
|
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||||
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
|
// Gemm output: [numPatches, cOut]
|
||||||
|
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||||
|
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||||
|
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||||
|
|
||||||
|
auto elemType = xType.getElementType();
|
||||||
|
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||||
|
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||||
|
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||||
|
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||||
|
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||||
|
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||||
|
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||||
|
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||||
|
|
||||||
|
// Prepare weight matrix W for crossbar storage:
|
||||||
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
|
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
wFlatType,
|
||||||
|
w,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
|
||||||
|
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||||
|
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
|
Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
Value biasMatrix;
|
||||||
|
DenseElementsAttr biasDenseAttr;
|
||||||
|
if (hasB) {
|
||||||
|
gemmBias = b;
|
||||||
|
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||||
|
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||||
|
}
|
||||||
|
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||||
|
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
|
||||||
|
const int64_t effectiveMaxParallelPixels =
|
||||||
|
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
||||||
|
|
||||||
|
// Keep the standard im2col view of convolution:
|
||||||
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
|
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
||||||
|
//
|
||||||
|
// We want to process N pixels at the same time. Instead of doing N separate operations
|
||||||
|
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
||||||
|
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
||||||
|
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||||
|
// B_packed: [N * patchSize, N * cOut]
|
||||||
|
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||||
|
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||||
|
auto gemmOutputRowsType =
|
||||||
|
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
|
Value gemmInputRows = createIm2colRowComputes(x,
|
||||||
|
xType,
|
||||||
|
im2colType,
|
||||||
|
rowType,
|
||||||
|
gemmInputRowsType,
|
||||||
|
batchSize,
|
||||||
|
numChannelsIn,
|
||||||
|
xHeight,
|
||||||
|
xWidth,
|
||||||
|
wHeight,
|
||||||
|
wWidth,
|
||||||
|
padHeightBegin,
|
||||||
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
outWidth,
|
||||||
|
patchSize,
|
||||||
|
numPatches,
|
||||||
|
numPatchesPerBatch,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
|
||||||
|
Value gemmB = buildPackedWeight(wDenseAttr,
|
||||||
|
wTrans,
|
||||||
|
wType,
|
||||||
|
numChannelsIn,
|
||||||
|
numChannelsOut,
|
||||||
|
wHeight,
|
||||||
|
wWidth,
|
||||||
|
patchSize,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
Value gemmC = buildPackedBias(
|
||||||
|
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
|
||||||
|
Value gemmRows = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmOutputRowsType,
|
||||||
|
gemmInputRows,
|
||||||
|
gemmB,
|
||||||
|
gemmC,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
|
||||||
|
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||||
|
outType,
|
||||||
|
gemmOutType,
|
||||||
|
nhwcType,
|
||||||
|
outType,
|
||||||
|
numPatches,
|
||||||
|
numChannelsOut,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||||
@@ -388,11 +514,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
auto wType = cast<RankedTensorType>(w.getType());
|
auto wType = cast<RankedTensorType>(w.getType());
|
||||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
if (!xType.hasStaticShape()) {
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||||
|
return failure();
|
||||||
// We need to understand what is group
|
}
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
if (!wType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (xType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (wType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (outType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (convOp.getGroup() < 1) {
|
||||||
|
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t batchSize = xType.getDimSize(0);
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
@@ -403,12 +552,51 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const int64_t wWidth = wType.getDimSize(3);
|
const int64_t wWidth = wType.getDimSize(3);
|
||||||
const int64_t outHeight = outType.getDimSize(2);
|
const int64_t outHeight = outType.getDimSize(2);
|
||||||
const int64_t outWidth = outType.getDimSize(3);
|
const int64_t outWidth = outType.getDimSize(3);
|
||||||
|
const int64_t group = convOp.getGroup();
|
||||||
|
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
|
|
||||||
|
if (numChannelsIn % group != 0) {
|
||||||
|
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||||
|
<< " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (numChannelsOut % group != 0) {
|
||||||
|
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||||
|
<< " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||||
|
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||||
|
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||||
|
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||||
|
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (wType.getDimSize(0) != numChannelsOut) {
|
||||||
|
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||||
|
<< numChannelsOut << " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||||
const auto stridesAttr = convOp.getStrides();
|
const auto stridesAttr = convOp.getStrides();
|
||||||
const auto dilationsAttr = convOp.getDilations();
|
const auto dilationsAttr = convOp.getDilations();
|
||||||
const auto padsAttr = convOp.getPads();
|
const auto padsAttr = convOp.getPads();
|
||||||
|
|
||||||
|
if (stridesAttr && stridesAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (padsAttr && padsAttr->size() != 4) {
|
||||||
|
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||||
@@ -449,67 +637,21 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
padWidthBegin = totalPadW - padWidthEnd;
|
padWidthBegin = totalPadW - padWidthEnd;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||||
|
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
// "NOTSET" or "VALID" -> all pads stay 0
|
// "NOTSET" or "VALID" -> all pads stay 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
if (group == 1) {
|
||||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
rewriter.replaceOp(convOp,
|
||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
lowerSingleConvGroup(x,
|
||||||
// Gemm output: [numPatches, cOut]
|
|
||||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
|
||||||
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
|
||||||
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
|
||||||
|
|
||||||
auto elemType = xType.getElementType();
|
|
||||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
|
||||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
|
||||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
|
||||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
|
||||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
|
||||||
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
|
||||||
|
|
||||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
|
||||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
|
||||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
|
||||||
auto wDenseAttr = getDenseConstantAttr(w);
|
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
|
||||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
wFlatType,
|
|
||||||
w,
|
w,
|
||||||
SmallVector<ReassociationIndices> {
|
b,
|
||||||
{0},
|
|
||||||
{1, 2, 3}
|
|
||||||
});
|
|
||||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
|
|
||||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
|
||||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
|
||||||
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
|
||||||
Value biasMatrix;
|
|
||||||
DenseElementsAttr biasDenseAttr;
|
|
||||||
if (hasB) {
|
|
||||||
gemmC = b;
|
|
||||||
biasDenseAttr = getDenseConstantAttr(b);
|
|
||||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
|
||||||
}
|
|
||||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
|
||||||
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
|
|
||||||
const int64_t effectiveMaxParallelPixels =
|
|
||||||
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
|
||||||
|
|
||||||
Value im2col = createIm2colCompute(x,
|
|
||||||
xType,
|
xType,
|
||||||
im2colType,
|
wType,
|
||||||
rowType,
|
outType,
|
||||||
batchSize,
|
|
||||||
numChannelsIn,
|
|
||||||
xHeight,
|
|
||||||
xWidth,
|
|
||||||
wHeight,
|
|
||||||
wWidth,
|
|
||||||
padHeightBegin,
|
padHeightBegin,
|
||||||
padHeightEnd,
|
padHeightEnd,
|
||||||
padWidthBegin,
|
padWidthBegin,
|
||||||
@@ -518,76 +660,74 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
strideWidth,
|
strideWidth,
|
||||||
dilationHeight,
|
dilationHeight,
|
||||||
dilationWidth,
|
dilationWidth,
|
||||||
outWidth,
|
|
||||||
patchSize,
|
|
||||||
numPatches,
|
|
||||||
numPatchesPerBatch,
|
|
||||||
rewriter,
|
rewriter,
|
||||||
loc);
|
loc));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
Value gemmOut;
|
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||||
if (effectiveMaxParallelPixels == 1) {
|
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||||
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
|
SmallVector<Value> bSlices;
|
||||||
gemmOut = ONNXGemmOp::create(rewriter,
|
if (hasB) {
|
||||||
loc,
|
auto biasType = cast<RankedTensorType>(b.getType());
|
||||||
gemmOutType,
|
int64_t biasAxis = -1;
|
||||||
im2col,
|
if (biasType.getRank() == 1)
|
||||||
wTrans,
|
biasAxis = 0;
|
||||||
gemmC,
|
else if (biasType.getRank() == 2)
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
else {
|
||||||
rewriter.getBoolAttr(false),
|
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||||
rewriter.getBoolAttr(false))
|
<< biasType.getRank();
|
||||||
.getY();
|
return failure();
|
||||||
|
}
|
||||||
|
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||||
|
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||||
|
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> groupResults;
|
||||||
|
groupResults.reserve(group);
|
||||||
|
auto groupOutType =
|
||||||
|
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||||
|
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||||
|
Value groupX = xSlices[groupId];
|
||||||
|
Value groupW = wSlices[groupId];
|
||||||
|
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||||
|
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||||
|
groupW,
|
||||||
|
groupB,
|
||||||
|
cast<RankedTensorType>(groupX.getType()),
|
||||||
|
cast<RankedTensorType>(groupW.getType()),
|
||||||
|
groupOutType,
|
||||||
|
padHeightBegin,
|
||||||
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
rewriter,
|
||||||
|
loc));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result;
|
||||||
|
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||||
|
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// Keep the standard im2col view of convolution:
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
});
|
||||||
// but repack several old rows into one new row so we use the available crossbar size better.
|
result = concatCompute.getResult(0);
|
||||||
//
|
|
||||||
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
|
|
||||||
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
|
||||||
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
|
||||||
// A_packed: [ceil(numPatches / N), N * patchSize]
|
|
||||||
// B_packed: [N * patchSize, N * cOut]
|
|
||||||
// Y_packed: [ceil(numPatches / N), N * cOut]
|
|
||||||
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
|
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
|
||||||
auto packedOutType =
|
|
||||||
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
|
||||||
|
|
||||||
Value packedA = createPackedIm2colRows(
|
|
||||||
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
|
|
||||||
Value packedB = buildPackedWeight(wDenseAttr,
|
|
||||||
wTrans,
|
|
||||||
wType,
|
|
||||||
numChannelsIn,
|
|
||||||
numChannelsOut,
|
|
||||||
wHeight,
|
|
||||||
wWidth,
|
|
||||||
patchSize,
|
|
||||||
effectiveMaxParallelPixels,
|
|
||||||
rewriter,
|
|
||||||
loc);
|
|
||||||
Value packedC = buildPackedBias(
|
|
||||||
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
|
||||||
Value packedOut = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
packedOutType,
|
|
||||||
packedA,
|
|
||||||
packedB,
|
|
||||||
packedC,
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getBoolAttr(false),
|
|
||||||
rewriter.getBoolAttr(false))
|
|
||||||
.getY();
|
|
||||||
gemmOut = createUnpackedOutput(
|
|
||||||
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
|
rewriter.replaceOp(convOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -15,13 +16,6 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
|
|
||||||
strides[i] = strides[i + 1] * shape[i + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
|||||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value transposeForSpatial(Value value,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<int64_t> permutation,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||||
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
resultType,
|
||||||
|
value,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1}
|
||||||
|
});
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
resultType,
|
||||||
|
input,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1}
|
||||||
|
});
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -65,6 +105,72 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
|
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||||
|
RankedTensorType matrixType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t numRows = matrixType.getDimSize(0);
|
||||||
|
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||||
|
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||||
|
|
||||||
|
if (isHostFoldableValue(matrix)) {
|
||||||
|
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||||
|
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto buildRowSlices = [&](Value matrixArg) {
|
||||||
|
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||||
|
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||||
|
};
|
||||||
|
|
||||||
|
auto cloneBatchInputChainIntoSliceCompute =
|
||||||
|
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
|
||||||
|
auto sliceCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
|
||||||
|
Value transformedMatrix = input;
|
||||||
|
if (!chainOps.empty()) {
|
||||||
|
IRMapping mapper;
|
||||||
|
mapper.map(rootValue, input);
|
||||||
|
for (Operation* chainOp : chainOps)
|
||||||
|
rewriter.clone(*chainOp, mapper);
|
||||||
|
transformedMatrix = cast<Value>(mapper.lookup(matrix));
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
|
||||||
|
});
|
||||||
|
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
|
||||||
|
return rowSlices;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<Operation*> chainOps;
|
||||||
|
Value rootValue = matrix;
|
||||||
|
while (Operation* definingOp = rootValue.getDefiningOp()) {
|
||||||
|
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
|
||||||
|
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||||
|
return cloneBatchInputChainIntoSliceCompute(
|
||||||
|
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (definingOp->getNumOperands() != 1)
|
||||||
|
break;
|
||||||
|
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||||
|
break;
|
||||||
|
|
||||||
|
chainOps.push_back(definingOp);
|
||||||
|
rootValue = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||||
|
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
@@ -75,13 +181,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
Value b = gemmOpAdaptor.getB();
|
Value b = gemmOpAdaptor.getB();
|
||||||
Value c = gemmOpAdaptor.getC();
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
|
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
auto aType = cast<RankedTensorType>(a.getType());
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
|
if (!aType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t numOutRows = aType.getDimSize(0);
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
|
|
||||||
@@ -105,47 +221,43 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||||
if (cType.getRank() == 1) {
|
if (cType.getRank() == 1) {
|
||||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
c = tensor::ExpandShapeOp::create(rewriter,
|
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||||
loc,
|
|
||||||
expandedType,
|
|
||||||
c,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1}
|
|
||||||
});
|
|
||||||
cType = expandedType;
|
cType = expandedType;
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||||
|
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||||
|
SmallVector<Value> cSlices;
|
||||||
|
if (hasC && cHasNumOutRows)
|
||||||
|
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
|
||||||
|
|
||||||
SmallVector<Value> gemvOps;
|
SmallVector<Value> gemvOps;
|
||||||
gemvOps.reserve(numOutRows);
|
gemvOps.reserve(static_cast<size_t>(numOutRows));
|
||||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
|
||||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
|
||||||
|
|
||||||
Value cSlice = c;
|
Value cSlice = c;
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
if (cHasNumOutRows) {
|
if (cHasNumOutRows)
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
cSlice = cSlices[static_cast<size_t>(rowIdx)];
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
else if (!isVectorShape(getTensorShape(c))) {
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
return failure();
|
||||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
|
||||||
}
|
}
|
||||||
else
|
|
||||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outRowType,
|
outRowType,
|
||||||
aSlice,
|
aSlices[static_cast<size_t>(rowIdx)],
|
||||||
b,
|
b,
|
||||||
cSlice,
|
cSlice,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
@@ -156,8 +268,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
@@ -189,20 +300,31 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||||
if (cType.getRank() == 1) {
|
if (cType.getRank() == 1) {
|
||||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
c = tensor::ExpandShapeOp::create(rewriter,
|
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
|
||||||
gemmLoc,
|
|
||||||
expandedType,
|
|
||||||
c,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1}
|
|
||||||
});
|
|
||||||
cType = expandedType;
|
cType = expandedType;
|
||||||
}
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
if (!aType.hasStaticShape()) {
|
||||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!bType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||||
// Not a gemv
|
// Not a gemv
|
||||||
@@ -210,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
if (transA) {
|
if (transA) {
|
||||||
auto aShape = aType.getShape();
|
auto aShape = aType.getShape();
|
||||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
|
||||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||||
|
aType = cast<RankedTensorType>(a.getType());
|
||||||
}
|
}
|
||||||
if (transB) {
|
if (transB) {
|
||||||
auto bShape = bType.getShape();
|
auto bShape = bType.getShape();
|
||||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||||
bType = cast<RankedTensorType>(b.getType());
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||||
auto bNumVSlices = aNumHSlices;
|
auto bNumVSlices = aNumHSlices;
|
||||||
auto bLastVSliceSize = aLastHSliceSize;
|
|
||||||
auto cNumHSlices = bNumHSlices;
|
auto cNumHSlices = bNumHSlices;
|
||||||
auto cLastHSliceSize = bLastHSliceSize;
|
auto cLastHSliceSize = bLastHSliceSize;
|
||||||
auto outNumHSlices = cNumHSlices;
|
auto outNumHSlices = cNumHSlices;
|
||||||
@@ -280,20 +402,39 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||||
|
|
||||||
auto computeOp = createSpatCompute(
|
auto computeOp =
|
||||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
|
||||||
|
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
|
||||||
|
for (Value weight : weights) {
|
||||||
|
blockArgTypes.push_back(weight.getType());
|
||||||
|
blockArgLocs.push_back(gemmLoc);
|
||||||
|
}
|
||||||
|
for (Value input : aHSlices[coreId]) {
|
||||||
|
blockArgTypes.push_back(input.getType());
|
||||||
|
blockArgLocs.push_back(gemmLoc);
|
||||||
|
}
|
||||||
|
Block* body =
|
||||||
|
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
SmallVector<Value> vmmOutputs;
|
SmallVector<Value> vmmOutputs;
|
||||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||||
vmmOutputs.push_back(
|
vmmOutputs.push_back(spatial::SpatVMMOp::create(
|
||||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
|
||||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
if (vmmOutputs.empty()) {
|
||||||
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||||
});
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
|
||||||
partialResults.push_back(computeOp.getResult(0));
|
partialResults.push_back(computeOp->getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
@@ -313,15 +454,141 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp =
|
||||||
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
|
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Location loc = gemmOp.getLoc();
|
||||||
|
Value a = gemmOpAdaptor.getA();
|
||||||
|
Value b = gemmOpAdaptor.getB();
|
||||||
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
|
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
|
|
||||||
|
auto aType = cast<RankedTensorType>(a.getType());
|
||||||
|
auto bType = cast<RankedTensorType>(b.getType());
|
||||||
|
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
|
if (!aType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!bType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t numOutRows = aType.getDimSize(0);
|
||||||
|
if (numOutRows <= 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
|
||||||
|
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|
||||||
|
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||||
|
if (failed(scaledB))
|
||||||
|
return failure();
|
||||||
|
b = *scaledB;
|
||||||
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
|
|
||||||
|
if (gemmOpAdaptor.getTransB()) {
|
||||||
|
auto bShape = bType.getShape();
|
||||||
|
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||||
|
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||||
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
|
}
|
||||||
|
(void) bType;
|
||||||
|
|
||||||
|
Value sharedBias;
|
||||||
|
if (hasC) {
|
||||||
|
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||||
|
if (failed(scaledC))
|
||||||
|
return failure();
|
||||||
|
c = *scaledC;
|
||||||
|
auto cType = cast<RankedTensorType>(c.getType());
|
||||||
|
if (cType.getRank() == 1) {
|
||||||
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
|
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||||
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
|
}
|
||||||
|
if (!cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (cType.getRank() != 2) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
||||||
|
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
||||||
|
return failure();
|
||||||
|
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||||
|
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
|
||||||
|
sharedBias = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||||
|
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||||
|
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {outType},
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||||
|
ValueRange {b},
|
||||||
|
ValueRange {a});
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
|
||||||
|
SmallVector<Location> blockArgLocs(4, loc);
|
||||||
|
Block* body =
|
||||||
|
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
|
Value lane = batchOp.getLaneArgument();
|
||||||
|
Value weight = batchOp.getWeightArgument(0);
|
||||||
|
Value packedInput = batchOp.getInputArgument(0);
|
||||||
|
Value packedOutput = batchOp.getOutputArgument(0);
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value row =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
|
||||||
|
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
|
||||||
|
Value laneResult = vmmResult;
|
||||||
|
if (sharedBias)
|
||||||
|
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||||
|
|
||||||
|
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||||
|
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||||
|
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
|
||||||
|
unitStrides);
|
||||||
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
|
||||||
|
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
|
||||||
patterns.insert<GemmToManyGemv>(ctx);
|
patterns.insert<GemmToManyGemv>(ctx);
|
||||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,15 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include <functional>
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -14,7 +19,191 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||||
|
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||||
|
ArrayRef<int64_t> rhsBatchShape) {
|
||||||
|
if (lhsBatchShape.empty())
|
||||||
|
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||||
|
if (rhsBatchShape.empty())
|
||||||
|
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||||
|
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||||
|
return failure();
|
||||||
|
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value collapseBatchDims(Value value,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t rows,
|
||||||
|
int64_t cols,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
if (type.getRank() == 2 || type.getRank() == 3)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto collapsedType =
|
||||||
|
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||||
|
SmallVector<ReassociationIndices> reassociation = {
|
||||||
|
ReassociationIndices {},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||||
|
};
|
||||||
|
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||||
|
reassociation.front().push_back(dim);
|
||||||
|
|
||||||
|
auto buildCollapsed = [&](Value input) -> Value {
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return buildCollapsed(value);
|
||||||
|
|
||||||
|
auto collapseCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||||
|
});
|
||||||
|
return collapseCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value expandBatchDims(Value value,
|
||||||
|
RankedTensorType outputType,
|
||||||
|
size_t batchRank,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation = {
|
||||||
|
ReassociationIndices {},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||||
|
};
|
||||||
|
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||||
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
|
||||||
|
auto expandCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||||
|
});
|
||||||
|
return expandCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractBatchMatrix(Value value,
|
||||||
|
int64_t batchIndex,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t rows,
|
||||||
|
int64_t cols,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
if (type.getRank() == 2)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets = {
|
||||||
|
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||||
|
auto buildMatrix = [&](Value input) -> Value {
|
||||||
|
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
matrixType,
|
||||||
|
slice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return buildMatrix(value);
|
||||||
|
|
||||||
|
auto batchMatrixCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||||
|
});
|
||||||
|
return batchMatrixCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
RankedTensorType transposedType;
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
if (type.getRank() == 2) {
|
||||||
|
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
|
perm = {1, 0};
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
perm = {0, 2, 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto buildTranspose = [&](Value input) -> Value {
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return buildTranspose(value);
|
||||||
|
|
||||||
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||||
|
});
|
||||||
|
return transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
RankedTensorType transposedType;
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
if (type.getRank() == 2) {
|
||||||
|
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
|
perm = {1, 0};
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
perm = {0, 2, 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
|
});
|
||||||
|
return transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||||
|
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
for (Value input : inputs)
|
||||||
|
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||||
|
outputShape[axis] = concatDimSize;
|
||||||
|
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
|
||||||
|
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||||
|
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||||
|
|
||||||
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||||
|
});
|
||||||
|
return concatCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
@@ -24,80 +213,129 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|| !outType.hasStaticShape())
|
|| !outType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||||
|
return failure();
|
||||||
|
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||||
|
|| !haveStaticPositiveShape(outType.getShape()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t batch = rhsType.getDimSize(0);
|
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||||
const int64_t k = rhsType.getDimSize(1);
|
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||||
const int64_t n = rhsType.getDimSize(2);
|
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||||
const int64_t m = lhsType.getDimSize(0);
|
if (failed(batchShape))
|
||||||
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
|
||||||
|| outType.getDimSize(2) != n)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||||
|
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||||
|
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||||
|
|
||||||
|
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||||
|
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||||
|
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||||
|
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||||
|
if (k != rhsK)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (outType.getRank() == 2) {
|
||||||
|
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||||
|
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||||
|
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||||
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
|
||||||
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
|
||||||
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
|
||||||
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
|
||||||
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
|
||||||
|
|
||||||
Value lhsTransposed =
|
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||||
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||||
|
int64_t lhsBatchForGemm = lhsBatch;
|
||||||
|
int64_t rhsBatchForGemm = rhsBatch;
|
||||||
|
int64_t gemmM = m;
|
||||||
|
int64_t gemmK = k;
|
||||||
|
int64_t gemmN = n;
|
||||||
|
if (useTransposedForm) {
|
||||||
|
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||||
|
lhsBatchForGemm = rhsBatch;
|
||||||
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
|
rhsBatchForGemm = lhsBatch;
|
||||||
|
gemmM = n;
|
||||||
|
gemmN = m;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
||||||
|
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
||||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
SmallVector<Value> gemmRows;
|
if (outType.getRank() == 2) {
|
||||||
gemmRows.reserve(batch * n);
|
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
SmallVector<OpFoldResult> offsets = {
|
|
||||||
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
|
||||||
SmallVector<OpFoldResult> strides = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Value rhsSlice =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
|
||||||
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
loc,
|
||||||
rhsRowType,
|
gemmType,
|
||||||
rhsSlice,
|
lhsMatrix,
|
||||||
SmallVector<ReassociationIndices> {
|
rhsMatrix,
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
|
|
||||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
gemmRowType,
|
|
||||||
rhsRow,
|
|
||||||
lhsTransposed,
|
|
||||||
none,
|
none,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter.getBoolAttr(false),
|
rewriter.getBoolAttr(false),
|
||||||
rewriter.getBoolAttr(false));
|
rewriter.getBoolAttr(false))
|
||||||
gemmRows.push_back(gemmOp.getY());
|
.getY();
|
||||||
}
|
if (useTransposedForm) {
|
||||||
}
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
});
|
});
|
||||||
|
gemmResult = transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
Value gemmOut = concatComputeOp.getResult(0);
|
SmallVector<Value> batchResults;
|
||||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
batchResults.reserve(batch);
|
||||||
|
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||||
|
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
|
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
gemmExpandedType,
|
gemmType,
|
||||||
gemmOut,
|
lhsMatrix,
|
||||||
|
rhsMatrix,
|
||||||
|
none,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
auto batchResultCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
||||||
|
Value resultMatrix = input;
|
||||||
|
if (useTransposedForm) {
|
||||||
|
resultMatrix = ONNXTransposeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||||
|
input,
|
||||||
|
rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
}
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
batchedOutType,
|
||||||
|
resultMatrix,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0, 1},
|
{0, 1},
|
||||||
{2}
|
{2}
|
||||||
});
|
});
|
||||||
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||||
|
});
|
||||||
|
batchResults.push_back(batchResultCompute.getResult(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||||
|
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, result);
|
rewriter.replaceOp(matmulOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -106,7 +344,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<MatMulRank3ToGemm>(ctx);
|
patterns.insert<MatMulToGemm>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -5,8 +5,9 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
|
|||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||||
|
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||||
|
int64_t concatDimSize = 0;
|
||||||
|
for (Value input : inputs)
|
||||||
|
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||||
|
outputShape[axis] = concatDimSize;
|
||||||
|
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||||
|
|
||||||
|
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||||
|
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||||
|
|
||||||
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||||
|
});
|
||||||
|
return concatCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
static Value buildReduceMeanKeepdims(Value input,
|
static Value buildReduceMeanKeepdims(Value input,
|
||||||
ArrayRef<bool> reducedAxes,
|
ArrayRef<bool> reducedAxes,
|
||||||
int64_t axis,
|
int64_t axis,
|
||||||
@@ -100,8 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
|
|||||||
for (Value slice : slices)
|
for (Value slice : slices)
|
||||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||||
|
|
||||||
return reducedSlices.size() == 1 ? reducedSlices.front()
|
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||||
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||||
@@ -116,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
|||||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensor::CollapseShapeOp::create(
|
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||||
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
|
if (isHostFoldableValue(keepdimsValue))
|
||||||
.getResult();
|
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||||
|
|
||||||
|
auto squeezeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
|
||||||
|
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
|
||||||
|
});
|
||||||
|
return squeezeCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||||
|
|||||||
@@ -1,18 +1,20 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/APFloat.h"
|
||||||
|
#include "llvm/ADT/APInt.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -31,13 +33,6 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
|||||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
|
||||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
|
||||||
if (values.size() == 1)
|
|
||||||
return values.front();
|
|
||||||
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||||
@@ -52,27 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
|||||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ReduceOp>
|
static Value
|
||||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
if (!useMinimumValue)
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||||
|
|
||||||
Value reduced = windowValues.front();
|
if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||||
for (Value value : windowValues.drop_front())
|
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
|
||||||
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
|
||||||
return reduced;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
|
||||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
|
||||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
|
||||||
if (divisor == 1)
|
}
|
||||||
return reducedWindow;
|
|
||||||
|
|
||||||
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
llvm_unreachable("unsupported pool element type");
|
||||||
double scale = 1.0 / static_cast<double>(divisor);
|
}
|
||||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
|
||||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
|
||||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
Location loc,
|
||||||
|
RankedTensorType tensorType,
|
||||||
|
bool useMinimumValue) {
|
||||||
|
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
||||||
|
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PoolOp>
|
||||||
|
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
PoolOp poolOp,
|
||||||
|
Value input,
|
||||||
|
RankedTensorType inputType,
|
||||||
|
int64_t padTop,
|
||||||
|
int64_t padLeft,
|
||||||
|
int64_t padBottom,
|
||||||
|
int64_t padRight) {
|
||||||
|
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
|
||||||
|
return input;
|
||||||
|
|
||||||
|
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
|
||||||
|
inputType.getDimSize(1),
|
||||||
|
inputType.getDimSize(2) + padTop + padBottom,
|
||||||
|
inputType.getDimSize(3) + padLeft + padRight},
|
||||||
|
inputType.getElementType(),
|
||||||
|
inputType.getEncoding());
|
||||||
|
SmallVector<OpFoldResult> lowPads = {
|
||||||
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
|
||||||
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padBottom),
|
||||||
|
rewriter.getIndexAttr(padRight)};
|
||||||
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
|
||||||
|
auto* padBlock = new Block();
|
||||||
|
for (int index = 0; index < paddedType.getRank(); ++index)
|
||||||
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
|
padOp.getRegion().push_back(padBlock);
|
||||||
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
|
Value padValue =
|
||||||
|
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||||
|
tensor::YieldOp::create(rewriter, loc, padValue);
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
return padOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Operation* op,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t channels,
|
||||||
|
int64_t inputHeight,
|
||||||
|
int64_t inputWidth,
|
||||||
|
int64_t outputHeight,
|
||||||
|
int64_t outputWidth,
|
||||||
|
int64_t kernelHeight,
|
||||||
|
int64_t kernelWidth,
|
||||||
|
int64_t strideHeight,
|
||||||
|
int64_t strideWidth,
|
||||||
|
int64_t dilationHeight,
|
||||||
|
int64_t dilationWidth,
|
||||||
|
int64_t padTop,
|
||||||
|
int64_t padLeft,
|
||||||
|
bool countIncludePad) {
|
||||||
|
auto elemType = dyn_cast<FloatType>(outType.getElementType());
|
||||||
|
if (!elemType) {
|
||||||
|
op->emitOpError("AveragePool lowering requires a floating-point element type");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
|
||||||
|
SmallVector<Attribute> scaleValues;
|
||||||
|
scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
|
||||||
|
for (int64_t channel = 0; channel < channels; ++channel) {
|
||||||
|
(void) channel;
|
||||||
|
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||||
|
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||||
|
int64_t validCount = 0;
|
||||||
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
|
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||||
|
if (inH < 0 || inH >= inputHeight)
|
||||||
|
continue;
|
||||||
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
|
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||||
|
if (inW < 0 || inW >= inputWidth)
|
||||||
|
continue;
|
||||||
|
++validCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
|
||||||
|
if (divisor <= 0) {
|
||||||
|
op->emitOpError("AveragePool divisor must be positive");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PoolOp>
|
template <typename PoolOp>
|
||||||
@@ -150,89 +244,133 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(void) padBottom;
|
|
||||||
(void) padRight;
|
|
||||||
|
|
||||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||||
|
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
|
||||||
|
const bool countIncludePad = [&]() {
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
|
||||||
|
return poolOp.getCountIncludePad() == 1;
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
|
Value averageScaleTensor;
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
|
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
|
||||||
|
loc,
|
||||||
|
poolOp,
|
||||||
|
outType,
|
||||||
|
channels,
|
||||||
|
inputHeight,
|
||||||
|
inputWidth,
|
||||||
|
outputHeight,
|
||||||
|
outputWidth,
|
||||||
|
kernelHeight,
|
||||||
|
kernelWidth,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
padTop,
|
||||||
|
padLeft,
|
||||||
|
countIncludePad);
|
||||||
|
if (failed(maybeAverageScaleTensor))
|
||||||
|
return failure();
|
||||||
|
averageScaleTensor = *maybeAverageScaleTensor;
|
||||||
|
}
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||||
SmallVector<Value> batchResults;
|
Value paddedInput =
|
||||||
batchResults.reserve(batchSize);
|
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||||
|
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||||
|
|
||||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
SmallVector<Value> rows;
|
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||||
rows.reserve(outputHeight);
|
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
||||||
|
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
||||||
|
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
||||||
|
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||||
|
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||||
|
|
||||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||||
SmallVector<Value> rowPixels;
|
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||||
rowPixels.reserve(outputWidth);
|
|
||||||
|
|
||||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
Value outputPatchIndex = outputLoop.getInductionVar();
|
||||||
SmallVector<Value> outputChannelTiles;
|
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
|
||||||
outputChannelTiles.reserve(channelTileCount);
|
|
||||||
|
|
||||||
|
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||||
|
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||||
|
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||||
|
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||||
|
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||||
|
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||||
|
|
||||||
|
Value updatedOutput = pooledOutputAcc;
|
||||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||||
|
Value reducedWindow =
|
||||||
|
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||||
|
|
||||||
SmallVector<Value> windowValues;
|
|
||||||
windowValues.reserve(kernelHeight * kernelWidth);
|
|
||||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
Value paddedInH = windowBaseH;
|
||||||
if (inH < 0 || inH >= inputHeight)
|
if (kernelH * dilationHeight != 0) {
|
||||||
continue;
|
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
||||||
|
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||||
|
}
|
||||||
|
|
||||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
Value paddedInW = windowBaseW;
|
||||||
if (inW < 0 || inW >= inputWidth)
|
if (kernelW * dilationWidth != 0) {
|
||||||
continue;
|
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
||||||
|
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
SmallVector<OpFoldResult> offsets = {
|
||||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
||||||
rewriter.getIndexAttr(inH),
|
|
||||||
rewriter.getIndexAttr(inW)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(tileChannels),
|
rewriter.getIndexAttr(tileChannels),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> strides = {
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value windowValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||||
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||||
|
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
|
SmallVector<OpFoldResult> scaleOffsets = {
|
||||||
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||||
|
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(tileChannels),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1)};
|
||||||
Value windowValue =
|
SmallVector<OpFoldResult> scaleStrides = {
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||||
windowValues.push_back(windowValue);
|
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||||
}
|
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||||
|
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (windowValues.empty())
|
SmallVector<OpFoldResult> outputOffsets = {
|
||||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||||
|
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
rewriter.getIndexAttr(tileChannels),
|
||||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
rewriter.getIndexAttr(1),
|
||||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
rewriter.getIndexAttr(1)};
|
||||||
const int64_t divisor =
|
SmallVector<OpFoldResult> outputStrides = {
|
||||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
updatedOutput = tensor::InsertSliceOp::create(
|
||||||
|
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||||
}
|
}
|
||||||
|
|
||||||
outputChannelTiles.push_back(reducedWindow);
|
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||||
}
|
|
||||||
|
|
||||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
rewriter.setInsertionPointAfter(outputLoop);
|
||||||
}
|
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
|
||||||
|
|
||||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
|
||||||
}
|
|
||||||
|
|
||||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
|
||||||
}
|
|
||||||
|
|
||||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
if (failed(computeOp))
|
if (failed(computeOp))
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -21,34 +23,81 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
|||||||
return permutedShape;
|
return permutedShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value buildLoopSoftmaxSlice(Value input,
|
||||||
|
Value accumulator,
|
||||||
|
RankedTensorType inputType,
|
||||||
|
ArrayRef<Value> outerIndices,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
int64_t rank = inputType.getRank();
|
||||||
|
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
|
||||||
|
sliceShape.push_back(inputType.getDimSize(rank - 1));
|
||||||
|
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets;
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||||
|
offsets.reserve(rank);
|
||||||
|
sizes.reserve(rank);
|
||||||
|
|
||||||
|
for (Value outerIndex : outerIndices) {
|
||||||
|
offsets.push_back(outerIndex);
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
offsets.push_back(rewriter.getIndexAttr(0));
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
|
||||||
|
|
||||||
|
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||||
|
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
|
||||||
|
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildLoopSoftmaxNest(Value input,
|
||||||
|
Value accumulator,
|
||||||
|
RankedTensorType inputType,
|
||||||
|
int64_t axis,
|
||||||
|
SmallVectorImpl<Value>& outerIndices,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (axis == inputType.getRank() - 1)
|
||||||
|
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||||
|
|
||||||
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
|
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||||
|
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||||
|
|
||||||
|
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
|
Value loopIndex = loop.getInductionVar();
|
||||||
|
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||||
|
outerIndices.push_back(loopIndex);
|
||||||
|
Value updatedAccumulator =
|
||||||
|
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||||
|
outerIndices.pop_back();
|
||||||
|
|
||||||
|
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||||
|
rewriter.setInsertionPointAfter(loop);
|
||||||
|
return loop.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
if (inputType.getRank() == 1) {
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||||
});
|
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||||
return computeOp.getResult(0);
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
SmallVector<Value> outerIndices;
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||||
if (axis == inputType.getRank())
|
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||||
return createSoftmaxCompute(input, rewriter, loc);
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
if (axis == softmaxAxis)
|
|
||||||
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
|
|
||||||
|
|
||||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
|
||||||
SmallVector<Value> rebuiltSlices;
|
|
||||||
rebuiltSlices.reserve(slices.size());
|
|
||||||
for (Value slice : slices)
|
|
||||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
|
||||||
|
|
||||||
return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
|
|
||||||
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||||
@@ -68,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
Value input = adaptor.getInput();
|
Value input = adaptor.getInput();
|
||||||
Value result;
|
Value result;
|
||||||
if (axis == inputType.getRank() - 1) {
|
if (axis == inputType.getRank() - 1) {
|
||||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
SmallVector<int64_t> permutation;
|
SmallVector<int64_t> permutation;
|
||||||
@@ -91,10 +140,14 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||||
});
|
});
|
||||||
Value transposedInput = preTransposeCompute.getResult(0);
|
Value transposedInput = preTransposeCompute.getResult(0);
|
||||||
Value transposedResult = buildSoftmax(
|
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
auto postTransposeCompute =
|
||||||
result = ONNXTransposeOp::create(
|
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
Value transposed = ONNXTransposeOp::create(
|
||||||
|
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||||
|
});
|
||||||
|
result = postTransposeCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(softmaxOp, result);
|
rewriter.replaceOp(softmaxOp, result);
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -17,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
|||||||
auto inputs = adaptor.getInputs();
|
auto inputs = adaptor.getInputs();
|
||||||
int64_t axis = adaptor.getAxis();
|
int64_t axis = adaptor.getAxis();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
|
if (llvm::all_of(inputs, isHostFoldableValue)) {
|
||||||
|
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute(
|
||||||
|
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(
|
||||||
|
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
|
|||||||
}
|
}
|
||||||
if (slices.empty())
|
if (slices.empty())
|
||||||
return {};
|
return {};
|
||||||
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
return createSpatConcat(rewriter, loc, axis, slices);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
||||||
}
|
}
|
||||||
result = rows.size() == 1
|
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
|
||||||
? rows.front()
|
|
||||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -77,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
|||||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||||
|
SmallVector<ReassociationIndices> reassociation(1);
|
||||||
|
reassociation.front().reserve(rank);
|
||||||
|
for (size_t dim = 0; dim < rank; ++dim)
|
||||||
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
return reassociation;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||||
|
SmallVector<ReassociationIndices> reassociation(1);
|
||||||
|
reassociation.front().reserve(rank);
|
||||||
|
for (size_t dim = 0; dim < rank; ++dim)
|
||||||
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
return reassociation;
|
||||||
|
}
|
||||||
|
|
||||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -95,18 +114,50 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<ReassociationIndices> reassociation;
|
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||||
if (sourceType.getRank() > resultType.getRank()
|
if (isHostFoldableValue(adaptor.getData())) {
|
||||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
||||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sourceType.getRank() < resultType.getRank()
|
auto computeOp = createSpatCompute<1>(
|
||||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
||||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
Value reshaped = buildReshape(data);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(reshapeOp, computeOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation;
|
||||||
|
if (sourceType.getRank() > resultType.getRank()
|
||||||
|
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||||
|
return replaceWithReshape([&](Value data) {
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (sourceType.getRank() < resultType.getRank()
|
||||||
|
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||||
|
return replaceWithReshape([&](Value data) {
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return replaceWithReshape([&](Value data) -> Value {
|
||||||
|
Value reshaped = data;
|
||||||
|
if (sourceType.getRank() != 1) {
|
||||||
|
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||||
|
reshaped = tensor::CollapseShapeOp::create(
|
||||||
|
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||||
}
|
}
|
||||||
|
if (resultType.getRank() == 1)
|
||||||
|
return reshaped;
|
||||||
|
return tensor::ExpandShapeOp::create(
|
||||||
|
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
});
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user