Compare commits
40 Commits
143c8f960a
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a060f455b | ||
|
|
742df111e3 | ||
|
|
4e19650b80 | ||
|
|
ed359730f1 | ||
|
|
a4f3eed3e0 | ||
|
|
93e20c1dfc | ||
|
|
0478d979ff | ||
|
|
da01e6d697 | ||
|
|
f869925b64 | ||
|
|
f2d593f749 | ||
|
|
a1b29dffe0 | ||
|
|
661170a9aa | ||
|
|
461bdd808d | ||
|
|
50c545539b | ||
|
|
11916a2595 | ||
|
|
670d6ce94f | ||
|
|
2676f2c7ef | ||
|
|
45342190bb | ||
|
|
f629e0d99f | ||
|
|
568529ea5f | ||
|
|
ca2e1645bb | ||
|
|
6933804003 | ||
|
|
dbe646ac0d | ||
|
|
bb6dcd38a3 | ||
|
|
916a09414c | ||
|
|
db3f52a647 | ||
|
|
6e1de865bb | ||
|
|
4e50e056e3 | ||
|
|
771b44a2ed | ||
|
|
7ce1d2b34d | ||
|
|
584ca0b3c2 | ||
|
|
1348bb1c97 | ||
|
|
825188cc89 | ||
|
|
7202a4317d | ||
|
|
d4efa64b96 | ||
|
|
fef26cee9a | ||
|
|
29febb2bfd | ||
|
|
f24a60bfcd | ||
|
|
91ef6d9bc3 | ||
|
|
8ee1e5ece8 |
54
.github/actions/prepare-mlir-cache/action.yml
vendored
Normal file
54
.github/actions/prepare-mlir-cache/action.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
name: Prepare MLIR Cache
|
||||||
|
description: Restore or build the cached MLIR toolchain.
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
llvm-commit:
|
||||||
|
description: LLVM commit to build.
|
||||||
|
required: true
|
||||||
|
mold-linker-flags:
|
||||||
|
description: Linker flags used to force mold.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: composite
|
||||||
|
steps:
|
||||||
|
- name: Restore MLIR Cache
|
||||||
|
id: restore-mlir-cache
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: onnx-mlir/llvm-project
|
||||||
|
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
|
||||||
|
|
||||||
|
- name: Clone LLVM
|
||||||
|
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
|
||||||
|
cd onnx-mlir/llvm-project
|
||||||
|
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
|
||||||
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
|
- name: Build MLIR
|
||||||
|
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cmake -S onnx-mlir/llvm-project/llvm -B onnx-mlir/llvm-project/build -G Ninja \
|
||||||
|
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
|
||||||
|
-DLLVM_ENABLE_RUNTIMES="openmp" \
|
||||||
|
-DLLVM_TARGETS_TO_BUILD="host" \
|
||||||
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||||
|
-DLLVM_ENABLE_RTTI=ON \
|
||||||
|
-DENABLE_LIBOMPTARGET=OFF \
|
||||||
|
-DLLVM_ENABLE_LIBEDIT=OFF \
|
||||||
|
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||||
|
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||||
|
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
|
||||||
|
cmake --build onnx-mlir/llvm-project/build
|
||||||
|
|
||||||
|
- name: Save MLIR Cache
|
||||||
|
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: onnx-mlir/llvm-project
|
||||||
|
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
|
||||||
64
.github/actions/prepare-protobuf-cache/action.yml
vendored
Normal file
64
.github/actions/prepare-protobuf-cache/action.yml
vendored
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
name: Prepare Protobuf Cache
|
||||||
|
description: Restore or build the cached Protobuf installation.
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
protobuf-ref:
|
||||||
|
description: Protobuf tag or commit to build.
|
||||||
|
required: true
|
||||||
|
mold-linker-flags:
|
||||||
|
description: Linker flags used to force mold.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: composite
|
||||||
|
steps:
|
||||||
|
- name: Restore Protobuf Cache
|
||||||
|
id: restore-protobuf-cache
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
/usr/local/lib/libproto*
|
||||||
|
/usr/local/lib/libabsl*
|
||||||
|
/usr/local/lib/libutf8*
|
||||||
|
/usr/local/lib/cmake/protobuf
|
||||||
|
/usr/local/lib/cmake/absl
|
||||||
|
/usr/local/lib/cmake/utf8_range
|
||||||
|
/usr/local/include/google/protobuf
|
||||||
|
/usr/local/include/absl
|
||||||
|
/usr/local/include/utf8_range.h
|
||||||
|
/usr/local/bin/protoc*
|
||||||
|
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-ref }}-v2
|
||||||
|
restore-keys: |
|
||||||
|
protobuf-${{ runner.os }}-${{ inputs.protobuf-ref }}-v2
|
||||||
|
|
||||||
|
- name: Install Protobuf
|
||||||
|
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
git clone --depth 1 --branch ${{ inputs.protobuf-ref }} https://github.com/protocolbuffers/protobuf protobuf
|
||||||
|
cmake -S protobuf -B protobuf/build -G Ninja \
|
||||||
|
-DCMAKE_BUILD_TYPE=Release \
|
||||||
|
-Dprotobuf_BUILD_TESTS=OFF \
|
||||||
|
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||||
|
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||||
|
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
|
||||||
|
cmake --build protobuf/build
|
||||||
|
sudo cmake --install protobuf/build
|
||||||
|
rm -rf protobuf
|
||||||
|
|
||||||
|
- name: Save Protobuf Cache
|
||||||
|
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
/usr/local/lib/libproto*
|
||||||
|
/usr/local/lib/libabsl*
|
||||||
|
/usr/local/lib/libutf8*
|
||||||
|
/usr/local/lib/cmake/protobuf
|
||||||
|
/usr/local/lib/cmake/absl
|
||||||
|
/usr/local/lib/cmake/utf8_range
|
||||||
|
/usr/local/include/google/protobuf
|
||||||
|
/usr/local/include/absl
|
||||||
|
/usr/local/include/utf8_range.h
|
||||||
|
/usr/local/bin/protoc*
|
||||||
|
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-ref }}-v2
|
||||||
26
.github/actions/restore-raptor-build-cache/action.yml
vendored
Normal file
26
.github/actions/restore-raptor-build-cache/action.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
name: Restore Raptor Build Cache
|
||||||
|
description: Restore the cached Raptor build directory for incremental builds.
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
key:
|
||||||
|
description: Exact cache key to restore.
|
||||||
|
required: true
|
||||||
|
restore-keys:
|
||||||
|
description: Prefixes used to restore the most recent compatible cache.
|
||||||
|
required: false
|
||||||
|
|
||||||
|
outputs:
|
||||||
|
cache-hit:
|
||||||
|
description: Whether the exact cache key was restored.
|
||||||
|
value: ${{ steps.restore-raptor-build-cache.outputs.cache-hit }}
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: composite
|
||||||
|
steps:
|
||||||
|
- name: Restore Raptor Build Cache
|
||||||
|
id: restore-raptor-build-cache
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: build
|
||||||
|
key: ${{ inputs.key }}
|
||||||
|
restore-keys: ${{ inputs.restore-keys }}
|
||||||
16
.github/actions/save-raptor-build-cache/action.yml
vendored
Normal file
16
.github/actions/save-raptor-build-cache/action.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
name: Save Raptor Build Cache
|
||||||
|
description: Save the Raptor build directory after a successful incremental build.
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
key:
|
||||||
|
description: Cache key used to save the build directory.
|
||||||
|
required: true
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: composite
|
||||||
|
steps:
|
||||||
|
- name: Save Raptor Build Cache
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: build
|
||||||
|
key: ${{ inputs.key }}
|
||||||
50
.github/workflows/build_mlir_cache.yml
vendored
50
.github/workflows/build_mlir_cache.yml
vendored
@@ -1,50 +0,0 @@
|
|||||||
name: Build MLIR Cache
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_call:
|
|
||||||
inputs:
|
|
||||||
llvm-commit:
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build-mlir:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Cache MLIR build
|
|
||||||
id: cache-mlir
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: onnx-mlir/llvm-project
|
|
||||||
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
|
|
||||||
|
|
||||||
- name: Install build dependencies
|
|
||||||
if: steps.cache-mlir.outputs.cache-hit != 'true'
|
|
||||||
run: |
|
|
||||||
sudo apt update
|
|
||||||
sudo apt install -y cmake ninja-build build-essential
|
|
||||||
|
|
||||||
- name: Clone LLVM
|
|
||||||
if: steps.cache-mlir.outputs.cache-hit != 'true'
|
|
||||||
run: |
|
|
||||||
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
|
|
||||||
cd onnx-mlir/llvm-project
|
|
||||||
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
|
|
||||||
git checkout FETCH_HEAD
|
|
||||||
|
|
||||||
- name: Build MLIR
|
|
||||||
if: steps.cache-mlir.outputs.cache-hit != 'true'
|
|
||||||
run: |
|
|
||||||
mkdir -p onnx-mlir/llvm-project/build
|
|
||||||
cd onnx-mlir/llvm-project/build
|
|
||||||
cmake -G Ninja ../llvm \
|
|
||||||
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
|
|
||||||
-DLLVM_ENABLE_RUNTIMES="openmp" \
|
|
||||||
-DLLVM_TARGETS_TO_BUILD="host" \
|
|
||||||
-DCMAKE_BUILD_TYPE=Release \
|
|
||||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
|
||||||
-DLLVM_ENABLE_RTTI=ON \
|
|
||||||
-DENABLE_LIBOMPTARGET=OFF \
|
|
||||||
-DLLVM_ENABLE_LIBEDIT=OFF
|
|
||||||
cmake --build .
|
|
||||||
109
.github/workflows/validate_operations.yml
vendored
109
.github/workflows/validate_operations.yml
vendored
@@ -2,64 +2,58 @@ name: Validate Operations
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286
|
LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286
|
||||||
PROTOBUF_COMMIT: v34.0
|
PROTOBUF_REF: v34.0
|
||||||
|
CMAKE_VERSION: 4.3.0
|
||||||
|
MOLD_LINKER_FLAGS: -fuse-ld=mold
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# Expose env vars as outputs so they can be passed to reusable workflows
|
validate-operations:
|
||||||
config:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
llvm-commit: ${{ steps.set-vars.outputs.llvm-commit }}
|
|
||||||
steps:
|
|
||||||
- id: set-vars
|
|
||||||
run: echo "llvm-commit=$LLVM_COMMIT" >> "$GITHUB_OUTPUT"
|
|
||||||
|
|
||||||
build-mlir-cache:
|
|
||||||
needs: config
|
|
||||||
uses: ./.github/workflows/build_mlir_cache.yml
|
|
||||||
with:
|
|
||||||
llvm-commit: ${{ needs.config.outputs.llvm-commit }}
|
|
||||||
|
|
||||||
validate:
|
|
||||||
needs: build-mlir-cache
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
run: |
|
||||||
with:
|
git clone --depth 1 --recurse-submodules --branch ${GITHUB_REF_NAME} \
|
||||||
submodules: recursive
|
https://chef.heaplab.deib.polimi.it/git/${GITHUB_REPOSITORY}.git \
|
||||||
|
${GITHUB_WORKSPACE}
|
||||||
|
|
||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install -y cmake ninja-build build-essential
|
sudo apt install -y cmake ninja-build build-essential mold curl ca-certificates
|
||||||
|
|
||||||
- name: Cache protobuf build
|
- name: Install CMake
|
||||||
id: cache-protobuf
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
/usr/local/lib/libproto*
|
|
||||||
/usr/local/include/google/protobuf
|
|
||||||
key: protobuf-${{ runner.os }}-${{ env.PROTOBUF_COMMIT }}
|
|
||||||
|
|
||||||
- name: Install protobuf
|
|
||||||
if: steps.cache-protobuf.outputs.cache-hit != 'true'
|
|
||||||
run: |
|
run: |
|
||||||
git clone --depth 1 --branch ${{ env.PROTOBUF_COMMIT }} https://github.com/protocolbuffers/protobuf
|
ARCH="$(uname -m)"
|
||||||
cd protobuf
|
case "$ARCH" in
|
||||||
mkdir build
|
x86_64) CMAKE_ARCH="linux-x86_64" ;;
|
||||||
cd build
|
aarch64|arm64) CMAKE_ARCH="linux-aarch64" ;;
|
||||||
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
|
*) echo "Unsupported architecture: $ARCH"; exit 1 ;;
|
||||||
ninja
|
esac
|
||||||
sudo ninja install
|
|
||||||
cd ../..
|
curl -fsSL -o /tmp/cmake.sh \
|
||||||
rm -rf protobuf
|
"https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-${CMAKE_ARCH}.sh"
|
||||||
|
sudo sh /tmp/cmake.sh --skip-license --prefix=/usr/local
|
||||||
|
rm -f /tmp/cmake.sh
|
||||||
|
|
||||||
|
cmake --version
|
||||||
|
which cmake
|
||||||
|
|
||||||
|
- name: Prepare MLIR Cache
|
||||||
|
uses: ./.github/actions/prepare-mlir-cache
|
||||||
|
with:
|
||||||
|
llvm-commit: ${{ env.LLVM_COMMIT }}
|
||||||
|
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
|
||||||
|
|
||||||
|
- name: Prepare Protobuf Cache
|
||||||
|
uses: ./.github/actions/prepare-protobuf-cache
|
||||||
|
with:
|
||||||
|
protobuf-ref: ${{ env.PROTOBUF_REF }}
|
||||||
|
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
|
||||||
|
|
||||||
- name: Register installed libraries
|
- name: Register installed libraries
|
||||||
run: sudo ldconfig
|
run: sudo ldconfig
|
||||||
@@ -75,26 +69,37 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: pip install numpy onnx colorama
|
run: pip install numpy onnx colorama
|
||||||
|
|
||||||
- name: Restore MLIR cache
|
- name: Restore Raptor Build Cache
|
||||||
uses: actions/cache/restore@v4
|
id: restore-raptor-build-cache
|
||||||
|
uses: ./.github/actions/restore-raptor-build-cache
|
||||||
with:
|
with:
|
||||||
path: onnx-mlir/llvm-project
|
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
|
||||||
key: mlir-${{ runner.os }}-${{ env.LLVM_COMMIT }}
|
restore-keys: |
|
||||||
fail-on-cache-miss: true
|
raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-
|
||||||
|
raptor-build-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-
|
||||||
|
|
||||||
- name: Build Raptor
|
- name: Build Raptor
|
||||||
|
id: build-raptor
|
||||||
run: |
|
run: |
|
||||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
|
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
|
||||||
mkdir -p build && cd build
|
cmake -S . -B build -G Ninja \
|
||||||
cmake .. -G Ninja \
|
|
||||||
-DCMAKE_BUILD_TYPE=Debug \
|
-DCMAKE_BUILD_TYPE=Debug \
|
||||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||||
-DMLIR_DIR=${MLIR_DIR}
|
-DMLIR_DIR=${MLIR_DIR} \
|
||||||
cmake --build .
|
-DCMAKE_EXE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
|
||||||
|
-DCMAKE_SHARED_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
|
||||||
|
-DCMAKE_MODULE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}"
|
||||||
|
cmake --build build
|
||||||
|
|
||||||
|
- name: Save Raptor Build Cache
|
||||||
|
if: steps.build-raptor.outcome == 'success' && steps.restore-raptor-build-cache.outputs.cache-hit != 'true'
|
||||||
|
uses: ./.github/actions/save-raptor-build-cache
|
||||||
|
with:
|
||||||
|
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
|
||||||
|
|
||||||
- name: Run validation
|
- name: Run validation
|
||||||
run: |
|
run: |
|
||||||
python validate.py \
|
python validation/validate.py \
|
||||||
--raptor-path build/Debug/bin/onnx-mlir \
|
--raptor-path build/Debug/bin/onnx-mlir \
|
||||||
--onnx-include-dir onnx-mlir/include
|
--onnx-include-dir onnx-mlir/include
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
.idea
|
.idea
|
||||||
|
.claude
|
||||||
|
AGENTS.md
|
||||||
build
|
build
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -29,7 +29,7 @@ Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
|
|||||||
|
|
||||||
Moreover, if compiling with build type debug, it is also suggested to use
|
Moreover, if compiling with build type debug, it is also suggested to use
|
||||||
mold as linker (you will need to install it if you don't have it already)
|
mold as linker (you will need to install it if you don't have it already)
|
||||||
to reduce memory usage during linking. You can use it with:
|
to reduce memory usage during linking. You can use it by setting the options:
|
||||||
```
|
```
|
||||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
|
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
|
||||||
@@ -38,8 +38,16 @@ to reduce memory usage during linking. You can use it with:
|
|||||||
### Raptor
|
### Raptor
|
||||||
|
|
||||||
Use the following commands to build Raptor.
|
Use the following commands to build Raptor.
|
||||||
|
|
||||||
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
|
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
|
||||||
|
|
||||||
|
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
|
||||||
|
setting the options:
|
||||||
|
```
|
||||||
|
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||||
|
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
|
||||||
|
```
|
||||||
|
|
||||||
```
|
```
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
|
|
||||||
|
|||||||
@@ -530,6 +530,7 @@ where
|
|||||||
let r2_val = r2;
|
let r2_val = r2;
|
||||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||||
let rd_val = core.register(rd);
|
let rd_val = core.register(rd);
|
||||||
|
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
||||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||||
let load1 = loads[0];
|
let load1 = loads[0];
|
||||||
@@ -700,6 +701,7 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
|||||||
let local_memory = core.load::<u8>(r1_val, imm_len)?;
|
let local_memory = core.load::<u8>(r1_val, imm_len)?;
|
||||||
let tmp = local_memory[0].to_vec();
|
let tmp = local_memory[0].to_vec();
|
||||||
core.execute_store(rd_val, tmp.as_slice());
|
core.execute_store(rd_val, tmp.as_slice());
|
||||||
|
TRACER.lock().unwrap().post_lmv(cores, data);
|
||||||
Ok(InstructionStatus::Completed)
|
Ok(InstructionStatus::Completed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,10 @@ impl Instruction {
|
|||||||
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
|
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn dump(&self) {
|
||||||
|
eprintln!("\t{}", functor_to_name(self.functor as usize));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Instructions = Vec<Instruction>;
|
pub type Instructions = Vec<Instruction>;
|
||||||
|
|||||||
@@ -71,18 +71,27 @@ pub fn json_to_instruction(
|
|||||||
inst_builder,
|
inst_builder,
|
||||||
inst_data_builder,
|
inst_data_builder,
|
||||||
json,
|
json,
|
||||||
);
|
)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! json_str {
|
macro_rules! json_str {
|
||||||
($json:ident , $value:literal) => {
|
($json:ident , $value:literal) => {
|
||||||
$json.get($value).context(concat![$value, " field not present"])?.as_str().context(concat![$value, " field not str"])?
|
$json
|
||||||
|
.get($value)
|
||||||
|
.context(concat![$value, " field not present"])?
|
||||||
|
.as_str()
|
||||||
|
.context(concat![$value, " field not str"])?
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! json_i64 {
|
macro_rules! json_i64 {
|
||||||
($json:ident , $value:literal) => {
|
($json:ident , $value:literal) => {
|
||||||
$json.get($value).context(concat![$value, " field not present"])?.as_i64().context(concat![$value, " field not i64"])?
|
$json
|
||||||
|
.get($value)
|
||||||
|
.context(concat![$value, " field not present"])?
|
||||||
|
.as_i64()
|
||||||
|
.context(concat![$value, " field not i64"])?
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +224,21 @@ fn json_to_vvsub(
|
|||||||
inst_data_builder: &mut InstructionDataBuilder,
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
todo!("Not present in the compiler");
|
let json = json.as_object().expect("Not an object");
|
||||||
|
assert_eq!("vvsub", json_str!(json, "op"));
|
||||||
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
|
let rs2 = json_i64!(json, "rs2") as i32;
|
||||||
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd(rd)
|
||||||
|
.set_r1(rs1)
|
||||||
|
.set_r2(rs2)
|
||||||
|
.set_imm_len(len)
|
||||||
|
.set_offset_select(offset_select)
|
||||||
|
.set_offset_value(offset_value);
|
||||||
|
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,7 +270,21 @@ fn json_to_vvdmul(
|
|||||||
inst_data_builder: &mut InstructionDataBuilder,
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
todo!("Not present in the compiler");
|
let json = json.as_object().expect("Not an object");
|
||||||
|
assert_eq!("vvdmul", json_str!(json, "op"));
|
||||||
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
|
let rs2 = json_i64!(json, "rs2") as i32;
|
||||||
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd(rd)
|
||||||
|
.set_r1(rs1)
|
||||||
|
.set_r2(rs2)
|
||||||
|
.set_imm_len(len)
|
||||||
|
.set_offset_select(offset_select)
|
||||||
|
.set_offset_value(offset_value);
|
||||||
|
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,7 +334,21 @@ fn json_to_vavg(
|
|||||||
inst_data_builder: &mut InstructionDataBuilder,
|
inst_data_builder: &mut InstructionDataBuilder,
|
||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
todo!("Not present in the compiler");
|
let json = json.as_object().expect("Not an object");
|
||||||
|
assert_eq!("vavg", json_str!(json, "op"));
|
||||||
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
|
let rs2 = json_i64!(json, "rs2") as i32;
|
||||||
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
|
inst_data_builder
|
||||||
|
.set_rd(rd)
|
||||||
|
.set_r1(rs1)
|
||||||
|
.set_r2(rs2)
|
||||||
|
.set_imm_len(len)
|
||||||
|
.set_offset_select(offset_select)
|
||||||
|
.set_offset_value(offset_value);
|
||||||
|
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,7 +400,7 @@ fn json_to_vsigm(
|
|||||||
json: &Value,
|
json: &Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let json = json.as_object().expect("Not an object");
|
let json = json.as_object().expect("Not an object");
|
||||||
assert_eq!("vsigmoid", json_str!(json, "op"));
|
assert_eq!("vsigm", json_str!(json, "op"));
|
||||||
let rd = json_i64!(json, "rd") as i32;
|
let rd = json_i64!(json, "rd") as i32;
|
||||||
let rs1 = json_i64!(json, "rs1") as i32;
|
let rs1 = json_i64!(json, "rs1") as i32;
|
||||||
let len = json_i64!(json, "len") as i32;
|
let len = json_i64!(json, "len") as i32;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
|
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
|
||||||
};
|
};
|
||||||
pub mod cpu;
|
pub mod cpu;
|
||||||
pub mod instruction_set;
|
pub mod instruction_set;
|
||||||
@@ -131,6 +131,16 @@ impl Executable {
|
|||||||
pub fn cpu_mut(&mut self) -> &mut CPU {
|
pub fn cpu_mut(&mut self) -> &mut CPU {
|
||||||
&mut self.cpu
|
&mut self.cpu
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dump(&self) {
|
||||||
|
let core_instructions = &self.core_instructions;
|
||||||
|
for (i, core_instruction) in core_instructions.iter().enumerate() {
|
||||||
|
eprintln!("INST OF CORE {}:", i);
|
||||||
|
for inst in &core_instruction.instructions {
|
||||||
|
inst.dump();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {
|
fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {
|
||||||
|
|||||||
@@ -1123,7 +1123,7 @@ impl Trace {
|
|||||||
if prefix == "Pre" {
|
if prefix == "Pre" {
|
||||||
writeln!(
|
writeln!(
|
||||||
file,
|
file,
|
||||||
"Inst: lvm {} {} {} {{ {} {} }}",
|
"Inst: lmv {} {} {} {{ {} {} }}",
|
||||||
rd, r1, imm_len, offset_select, offset_value
|
rd, r1, imm_len, offset_select, offset_value
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
@@ -1141,13 +1141,15 @@ impl Trace {
|
|||||||
);
|
);
|
||||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||||
let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
|
let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
|
||||||
let core_memory = core.load::<u8>(r1_val, imm_len).unwrap();
|
let core_memory = core
|
||||||
let global_memory = host.load::<u8>(rd_val, imm_len).unwrap();
|
.reserve_load(r1_val, imm_len).unwrap()
|
||||||
|
.reserve_load(rd_val, imm_len).unwrap()
|
||||||
|
.execute_load::<u8>().unwrap();
|
||||||
writeln!(file, "{} Memory:", prefix);
|
writeln!(file, "{} Memory:", prefix);
|
||||||
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
|
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
|
||||||
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
|
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
|
||||||
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
|
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
|
||||||
pretty_print::print_slice::<_,f32>(file, global_memory[0], 30);
|
pretty_print::print_slice::<_,f32>(file, core_memory[1], 30);
|
||||||
|
|
||||||
if prefix == "Post" {
|
if prefix == "Post" {
|
||||||
writeln!(file, "\n###############################################\n");
|
writeln!(file, "\n###############################################\n");
|
||||||
|
|||||||
Submodule onnx-mlir updated: 84cedd1d69...eb54c2afc4
@@ -10,38 +10,61 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
|||||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||||
|
|
||||||
add_subdirectory(Dialect)
|
set(PIM_PUBLIC_INCLUDE_DIRS
|
||||||
add_subdirectory(Compiler)
|
|
||||||
add_subdirectory(Conversion)
|
|
||||||
add_subdirectory(Common)
|
|
||||||
|
|
||||||
add_onnx_mlir_library(OMPIMAccel
|
|
||||||
PimAccelerator.cpp
|
|
||||||
Transforms/PimBufferizationPass.cpp
|
|
||||||
Pass/CountInstructionPass.cpp
|
|
||||||
Pass/EmitPimJsonPass.cpp
|
|
||||||
Pass/MessagePass.cpp
|
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
|
||||||
|
|
||||||
INCLUDE_DIRS PUBLIC
|
|
||||||
${ONNX_MLIR_SRC_ROOT}/include
|
${ONNX_MLIR_SRC_ROOT}/include
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||||
${PIM_SRC_ROOT}
|
${PIM_SRC_ROOT}
|
||||||
${PIM_BIN_ROOT}
|
${PIM_BIN_ROOT}
|
||||||
${PIM_INCLUDE_PATH}
|
${PIM_INCLUDE_PATH}
|
||||||
|
)
|
||||||
|
|
||||||
|
set(PIM_COMPILER_INCLUDE_DIRS
|
||||||
|
${PIM_SRC_ROOT}
|
||||||
|
${PIM_BIN_ROOT}
|
||||||
|
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||||
|
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||||
|
)
|
||||||
|
|
||||||
|
set(PIM_ACCEL_INCLUDE_DIRS
|
||||||
|
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||||
|
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||||
|
)
|
||||||
|
|
||||||
|
set(PIM_GENERATED_INCLUDE_DIRS
|
||||||
|
${PIM_INCLUDE_PATH}
|
||||||
|
)
|
||||||
|
|
||||||
|
function(add_pim_library name)
|
||||||
|
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
add_subdirectory(Dialect)
|
||||||
|
add_subdirectory(Common)
|
||||||
|
add_subdirectory(Pass)
|
||||||
|
add_subdirectory(Compiler)
|
||||||
|
add_subdirectory(Conversion)
|
||||||
|
|
||||||
|
add_pim_library(OMPIMAccel
|
||||||
|
PimAccelerator.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
INCLUDE_DIRS PUBLIC
|
||||||
|
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
onnx
|
onnx
|
||||||
OMAccelerator
|
OMAccelerator
|
||||||
OMPimCompilerUtils
|
OMPimCompilerUtils
|
||||||
OMCompilerUtils
|
OMPimPasses
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
PimOps
|
PimOps
|
||||||
OMONNXToSpatial
|
OMONNXToSpatial
|
||||||
OMSpatialToGraphviz
|
OMSpatialToGraphviz
|
||||||
OMSpatialToPIM
|
OMSpatialToPim
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
|
OMPimBufferization
|
||||||
|
MLIRTensorInferTypeOpInterfaceImpl
|
||||||
)
|
)
|
||||||
@@ -1,19 +1,13 @@
|
|||||||
add_onnx_mlir_library(OMPIMCommon
|
add_pim_library(OMPimCommon
|
||||||
PIMCommon.cpp
|
PimCommon.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
INCLUDE_DIRS PUBLIC
|
INCLUDE_DIRS PUBLIC
|
||||||
${ONNX_MLIR_SRC_ROOT}/include
|
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
|
||||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
|
||||||
${PIM_SRC_ROOT}
|
|
||||||
${PIM_BIN_ROOT}
|
|
||||||
${PIM_INCLUDE_PATH}
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
onnx
|
onnx
|
||||||
OMPimCompilerUtils
|
|
||||||
SpatialOps
|
SpatialOps
|
||||||
PimOps
|
PimOps
|
||||||
)
|
)
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <filesystem>
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory) {
|
|
||||||
std::error_code errorCode;
|
|
||||||
std::filesystem::create_directories(directory, errorCode);
|
|
||||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
|
||||||
}
|
|
||||||
|
|
||||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
|
||||||
std::string dialectsDir = getOutputDir() + "/dialects";
|
|
||||||
createDirectory(dialectsDir);
|
|
||||||
|
|
||||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
os << *moduleOp;
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
|
||||||
if (!channelNewOp) {
|
|
||||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
// channelNewOp should have two users: `op` and a
|
|
||||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
|
||||||
auto channelUsers = channelNewOp->getUsers();
|
|
||||||
auto usersIterator = channelUsers.begin();
|
|
||||||
auto firstUser = *usersIterator;
|
|
||||||
usersIterator++;
|
|
||||||
if (usersIterator == channelUsers.end()) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"only one found.");
|
|
||||||
channelNewOp->dump();
|
|
||||||
op->dump();
|
|
||||||
channelNewOp->getParentOp()->dump();
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
auto secondUser = *usersIterator;
|
|
||||||
usersIterator++;
|
|
||||||
if (usersIterator != channelUsers.end()) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"more than two found.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
Operation* notOpUser;
|
|
||||||
if (firstUser == op) {
|
|
||||||
notOpUser = secondUser;
|
|
||||||
}
|
|
||||||
else if (secondUser == op) {
|
|
||||||
notOpUser = firstUser;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
|
||||||
"and one of them must be me, but"
|
|
||||||
"none of them is actually me.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (opIsReceive) {
|
|
||||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
|
||||||
"me, the other is not a ChannelSendOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return notOpUser;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
|
||||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
|
||||||
"me, the other is not a ChannelReceiveOp.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return notOpUser;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
|
|
||||||
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
std::string getOutputDir();
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory);
|
|
||||||
|
|
||||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Operation*>
|
|
||||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
302
src/PIM/Common/PimCommon.cpp
Normal file
302
src/PIM/Common/PimCommon.cpp
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::string getOutputDir() {
|
||||||
|
if (outputBaseName.empty() || outputBaseName == "-")
|
||||||
|
return {};
|
||||||
|
|
||||||
|
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||||
|
if (lastSlash == std::string::npos)
|
||||||
|
return ".";
|
||||||
|
return outputBaseName.substr(0, lastSlash);
|
||||||
|
}
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory) {
|
||||||
|
std::error_code errorCode;
|
||||||
|
std::filesystem::create_directories(directory, errorCode);
|
||||||
|
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||||
|
}
|
||||||
|
|
||||||
|
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string dialectsDir = outputDir + "/dialects";
|
||||||
|
createDirectory(dialectsDir);
|
||||||
|
|
||||||
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
|
llvm::raw_os_ostream os(file);
|
||||||
|
os << *moduleOp;
|
||||||
|
os.flush();
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
||||||
|
if (entryPoints.size() > 1) {
|
||||||
|
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!entryPoints.empty()) {
|
||||||
|
auto entryPointAttr =
|
||||||
|
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||||
|
if (!entryPointAttr) {
|
||||||
|
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||||
|
if (!entryFunc) {
|
||||||
|
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||||
|
<< entryPointAttr.getLeafReference().getValue();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return entryFunc;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
||||||
|
return mainGraphFunc;
|
||||||
|
|
||||||
|
SmallVector<func::FuncOp> nonExternalFuncs;
|
||||||
|
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
||||||
|
if (!funcOp.isExternal())
|
||||||
|
nonExternalFuncs.push_back(funcOp);
|
||||||
|
if (nonExternalFuncs.size() == 1)
|
||||||
|
return nonExternalFuncs.front();
|
||||||
|
|
||||||
|
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||||
|
|
||||||
|
void markWeightAlways(Operation* op) {
|
||||||
|
assert(op && "expected valid op");
|
||||||
|
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
|
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (!moduleOp || !getGlobalOp)
|
||||||
|
return {};
|
||||||
|
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||||
|
|
||||||
|
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
if (!channelNewOp) {
|
||||||
|
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// channelNewOp should have two users: `op` and a
|
||||||
|
// `ChannelSendOp`/`ChannelReceiveOp`
|
||||||
|
auto channelUsers = channelNewOp->getUsers();
|
||||||
|
auto usersIterator = channelUsers.begin();
|
||||||
|
auto firstUser = *usersIterator;
|
||||||
|
usersIterator++;
|
||||||
|
if (usersIterator == channelUsers.end()) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
|
"only one found.");
|
||||||
|
channelNewOp->dump();
|
||||||
|
op->dump();
|
||||||
|
channelNewOp->getParentOp()->dump();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto secondUser = *usersIterator;
|
||||||
|
usersIterator++;
|
||||||
|
if (usersIterator != channelUsers.end()) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
|
"more than two found.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
Operation* notOpUser;
|
||||||
|
if (firstUser == op) {
|
||||||
|
notOpUser = secondUser;
|
||||||
|
}
|
||||||
|
else if (secondUser == op) {
|
||||||
|
notOpUser = firstUser;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
|
"and one of them must be me, but"
|
||||||
|
"none of them is actually me.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (opIsReceive) {
|
||||||
|
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||||
|
"me, the other is not a ChannelSendOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return notOpUser;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||||
|
"me, the other is not a ChannelReceiveOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return notOpUser;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
||||||
|
SmallVector<int64_t> indices(shape.size(), 0);
|
||||||
|
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||||
|
indices[dim] = linearIndex / stride;
|
||||||
|
linearIndex %= stride;
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
||||||
|
int64_t linearIndex = 0;
|
||||||
|
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||||
|
linearIndex += index * stride;
|
||||||
|
return linearIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t dim : shape)
|
||||||
|
numElements *= dim;
|
||||||
|
return numElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||||
|
ArrayRef<int64_t> offsets,
|
||||||
|
ArrayRef<int64_t> sizes,
|
||||||
|
ArrayRef<int64_t> strides) {
|
||||||
|
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||||
|
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstNonZeroOffset = std::find_if(
|
||||||
|
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return offset != 0;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||||
|
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||||
|
if (size > dimension - offset)
|
||||||
|
return false;
|
||||||
|
++firstNonZeroOffset;
|
||||||
|
|
||||||
|
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, dimension] = sizeAndShape;
|
||||||
|
return size != dimension;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstDifferentSize != sizesAndShape.end()) {
|
||||||
|
++firstDifferentSize;
|
||||||
|
|
||||||
|
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, _dimension] = sizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||||
|
if (!tiedOperand)
|
||||||
|
return failure();
|
||||||
|
value = tiedOperand->get();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||||
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
|
||||||
|
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
|
||||||
|
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
|
||||||
|
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|
||||||
|
|| llvm::is_contained(strides, ShapedType::kDynamic))
|
||||||
|
return failure();
|
||||||
|
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||||
|
value = subviewOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
56
src/PIM/Common/PimCommon.hpp
Normal file
56
src/PIM/Common/PimCommon.hpp
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
|
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct ResolvedContiguousAddress {
|
||||||
|
mlir::Value base;
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string getOutputDir();
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory);
|
||||||
|
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||||
|
|
||||||
|
bool hasWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::Operation*>
|
||||||
|
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
|
|
||||||
public:
|
|
||||||
llvm::DenseMap<mlir::Value, T> map;
|
|
||||||
|
|
||||||
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
|
|
||||||
: ForwardingListener(&listener) {}
|
|
||||||
|
|
||||||
void notifyOperationErased(mlir::Operation* op) override {
|
|
||||||
for (mlir::Value result : op->getResults())
|
|
||||||
map.erase(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
void notifyBlockErased(mlir::Block* block) override {
|
|
||||||
for (mlir::BlockArgument arg : block->getArguments())
|
|
||||||
map.erase(arg);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
|
|
||||||
public:
|
|
||||||
llvm::DenseMap<mlir::Value, T> map;
|
|
||||||
|
|
||||||
NotErasableValueMap(mlir::OpBuilder::Listener listener)
|
|
||||||
: ForwardingListener(&listener) {}
|
|
||||||
|
|
||||||
void notifyOperationErased(mlir::Operation* op) override {
|
|
||||||
for (mlir::Value result : op->getResults())
|
|
||||||
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
|
|
||||||
}
|
|
||||||
|
|
||||||
void notifyBlockErased(mlir::Block* block) override {
|
|
||||||
for (mlir::BlockArgument arg : block->getArguments())
|
|
||||||
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
@@ -1,44 +1,36 @@
|
|||||||
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
|
add_pim_library(OMPimCompilerOptions
|
||||||
|
|
||||||
add_onnx_mlir_library(OMPimCompilerOptions
|
|
||||||
PimCompilerOptions.cpp
|
PimCompilerOptions.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
INCLUDE_DIRS PRIVATE
|
INCLUDE_DIRS PRIVATE
|
||||||
${PIM_SRC_ROOT}
|
${PIM_COMPILER_INCLUDE_DIRS}
|
||||||
${PIM_BIN_ROOT}
|
|
||||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
|
||||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
${OMLibs}
|
|
||||||
OMCompilerOptions
|
OMCompilerOptions
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
${PIM_ACCEL_INCLUDE_DIRS}
|
||||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
add_onnx_mlir_library(OMPimCompilerUtils
|
add_pim_library(OMPimCompilerUtils
|
||||||
PimCompilerUtils.cpp
|
PimCompilerUtils.cpp
|
||||||
PimCodeGen.cpp
|
PimCodeGen.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
INCLUDE_DIRS PRIVATE
|
INCLUDE_DIRS PRIVATE
|
||||||
${PIM_SRC_ROOT}
|
${PIM_COMPILER_INCLUDE_DIRS}
|
||||||
${PIM_BIN_ROOT}
|
|
||||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
|
||||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
${OMLibs}
|
|
||||||
OMCompilerUtils
|
|
||||||
OMPimCompilerOptions
|
OMPimCompilerOptions
|
||||||
|
OMPimCommon
|
||||||
|
OMPimBufferization
|
||||||
|
OMPimPasses
|
||||||
|
OMONNXToSpatial
|
||||||
|
OMSpatialToPim
|
||||||
OMCompilerPasses
|
OMCompilerPasses
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
${PIM_ACCEL_INCLUDE_DIRS}
|
||||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,17 +13,18 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Compiler/CompilerPasses.hpp"
|
#include "src/Compiler/CompilerPasses.hpp"
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
|
||||||
|
|
||||||
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
MemEntry* PimMemory::gatherMemEntry(Value value) {
|
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||||
auto type = cast<ShapedType>(value.getType());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
assert("Only static shape is supported" && type.hasStaticShape());
|
assert("Only static shape is supported" && type.hasStaticShape());
|
||||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
@@ -31,7 +32,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
|
|||||||
return &memEntries.emplace_back(memEntry, value).first;
|
return &memEntries.emplace_back(memEntry, value).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) {
|
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||||
memEntry.address = firstAvailableAddress;
|
memEntry.address = firstAvailableAddress;
|
||||||
firstAvailableAddress += memEntry.size;
|
firstAvailableAddress += memEntry.size;
|
||||||
// Alignment
|
// Alignment
|
||||||
@@ -47,8 +48,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
||||||
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!getGlobalOp->hasAttr("weightAlways")) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
auto iter = globalConstants.find(globalMemrefOp);
|
auto iter = globalConstants.find(globalMemrefOp);
|
||||||
if (iter == globalConstants.end())
|
if (iter == globalConstants.end())
|
||||||
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
||||||
@@ -59,7 +60,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (Value arg : funcOp.getArguments())
|
for (mlir::Value arg : funcOp.getArguments())
|
||||||
gatherMemEntry(arg);
|
gatherMemEntry(arg);
|
||||||
|
|
||||||
allocateCore(funcOp);
|
allocateCore(funcOp);
|
||||||
@@ -73,40 +74,44 @@ void PimMemory::allocateCore(Operation* op) {
|
|||||||
allocateMemoryForValue(value, memEntry);
|
allocateMemoryForValue(value, memEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
MemEntry PimMemory::getMemEntry(Value value) const {
|
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||||
auto iter = globalMemEntriesMap.find(value);
|
auto iter = globalMemEntriesMap.find(value);
|
||||||
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||||
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t PimAcceleratorMemory::getValueAddress(Value value) const {
|
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||||
while (true) {
|
auto resolvedAddress = resolveContiguousAddress(value);
|
||||||
auto definingOp = value.getDefiningOp();
|
if (failed(resolvedAddress)) {
|
||||||
if (!definingOp)
|
errs() << "Failed to resolve contiguous address for value: ";
|
||||||
break;
|
value.print(errs());
|
||||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
errs() << "\n";
|
||||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
if (auto* definingOp = value.getDefiningOp()) {
|
||||||
if (!tiedOperand)
|
errs() << "Defining op:\n";
|
||||||
break;
|
definingOp->print(errs());
|
||||||
value = tiedOperand->get();
|
errs() << "\n";
|
||||||
}
|
}
|
||||||
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
llvm_unreachable("Failed to resolve contiguous address");
|
||||||
auto source = subviewDefiningOp.getSource();
|
|
||||||
auto srcShape = source.getType().getShape();
|
|
||||||
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
|
|
||||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
|
||||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
|
||||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
|
||||||
value = source;
|
|
||||||
}
|
}
|
||||||
else
|
|
||||||
break;
|
auto iter = memEntriesMap.find(resolvedAddress->base);
|
||||||
|
if (iter == memEntriesMap.end()) {
|
||||||
|
errs() << "Missing mem entry for value: ";
|
||||||
|
resolvedAddress->base.print(errs());
|
||||||
|
errs() << "\n";
|
||||||
|
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
|
||||||
|
errs() << "Defining op:\n";
|
||||||
|
definingOp->print(errs());
|
||||||
|
errs() << "\n";
|
||||||
}
|
}
|
||||||
return memEntriesMap.at(value).address;
|
llvm_unreachable("Missing mem entry");
|
||||||
|
}
|
||||||
|
|
||||||
|
return iter->second.address + resolvedAddress->byteOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
json::Object PimCodeGen::createEmptyOffset() {
|
json::Object PimCodeGen::createEmptyOffset() {
|
||||||
@@ -144,15 +149,20 @@ void PimCodeGen::setupRdRs1Rs2(
|
|||||||
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::emitMemCopyOp(
|
void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||||
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const {
|
size_t rdAddr,
|
||||||
|
size_t rdOffset,
|
||||||
|
size_t rs1Addr,
|
||||||
|
size_t rs1Offset,
|
||||||
|
size_t size,
|
||||||
|
StringRef sizeFieldName) const {
|
||||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = opName;
|
json["op"] = opName;
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["size"] = size;
|
json[sizeFieldName] = size;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
@@ -184,47 +194,59 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
|||||||
|
|
||||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
|
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
|
||||||
emitMemCopyOp("ld",
|
emitMemCopyOp("ld",
|
||||||
memory.getValueAddress(loadOp.getDeviceDst()),
|
memory.getValueAddress(loadOp.getDeviceTarget()),
|
||||||
loadOp.getDeviceDstOffset(),
|
loadOp.getDeviceTargetOffset(),
|
||||||
memory.getValueAddress(loadOp.getHostSrc()),
|
memory.getValueAddress(loadOp.getHostSource()),
|
||||||
loadOp.getHostSrcOffset(),
|
loadOp.getHostSourceOffset(),
|
||||||
loadOp.getSize());
|
loadOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
|
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
|
||||||
emitMemCopyOp("st",
|
emitMemCopyOp("st",
|
||||||
memory.getValueAddress(storeOp.getHostDst()),
|
memory.getValueAddress(storeOp.getHostTarget()),
|
||||||
storeOp.getHostDstOffset(),
|
storeOp.getHostTargetOffset(),
|
||||||
memory.getValueAddress(storeOp.getDeviceSrc()),
|
memory.getValueAddress(storeOp.getDeviceSource()),
|
||||||
storeOp.getDeviceSrcOffset(),
|
storeOp.getDeviceSourceOffset(),
|
||||||
storeOp.getSize());
|
storeOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
|
||||||
|
emitMemCopyOp("lmv",
|
||||||
|
memory.getValueAddress(lmvOp.getTarget()),
|
||||||
|
lmvOp.getTargetOffset(),
|
||||||
|
memory.getValueAddress(lmvOp.getSource()),
|
||||||
|
lmvOp.getSourceOffset(),
|
||||||
|
lmvOp.getSize(),
|
||||||
|
"len");
|
||||||
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
||||||
emitCommunicationOp(
|
emitCommunicationOp(
|
||||||
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
|
"recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
|
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
|
||||||
emitCommunicationOp("send", memory.getValueAddress(sendOp.getSrc()), sendOp.getTargetCoreId(), sendOp.getSize());
|
emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
|
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
|
||||||
emitMvmOp(
|
emitMvmOp(
|
||||||
mvmId, memory.getValueAddress(mvmLikeOp.getOutBuf()), 0, memory.getValueAddress(mvmLikeOp.getVectorInput()), 0);
|
mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0);
|
||||||
|
|
||||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||||
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
auto aAddr = memory.getValueAddress(vaddOp.getA());
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
auto bAddr = memory.getValueAddress(vaddOp.getB());
|
}
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
|
||||||
|
|
||||||
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
|
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||||
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
|
auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
|
||||||
|
auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs());
|
||||||
|
auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs());
|
||||||
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvadd";
|
json["op"] = "vvadd";
|
||||||
@@ -232,15 +254,47 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
json["len"] = totalBytes;
|
json["len"] = getValueSizeInBytes(vvaddOp.getLhs());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vmaxOp.getA());
|
auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs());
|
||||||
auto bAddr = memory.getValueAddress(vmaxOp.getB());
|
auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs());
|
||||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vvsub";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["rs2"] = 2;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvsubOp.getLhs());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||||
|
auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
|
||||||
|
auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs());
|
||||||
|
auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs());
|
||||||
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vvmul";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["rs2"] = 2;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvmulOp.getLhs());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||||
|
auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
|
||||||
|
auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs());
|
||||||
|
auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs());
|
||||||
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vvmax";
|
json["op"] = "vvmax";
|
||||||
@@ -248,79 +302,125 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
|||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["rs2"] = 2;
|
json["rs2"] = 2;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvmaxOp.getLhs());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||||
|
auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
|
||||||
|
auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs());
|
||||||
|
auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs());
|
||||||
|
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vvdmul";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["rs2"] = 2;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vvdmulOp.getLhs());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||||
|
auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
|
||||||
|
auto inputAddr = memory.getValueAddress(vavgOp.getInput());
|
||||||
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vavg";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vavgOp.getInput());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(vreluOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
|
||||||
auto aAddr = memory.getValueAddress(vreluOp.getA());
|
auto inputAddr = memory.getValueAddress(vreluOp.getInput());
|
||||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = "vrelu";
|
json["op"] = "vrelu";
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["rs1"] = 1;
|
json["rs1"] = 1;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vreluOp.getInput());
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const {
|
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||||
auto outBufAddr = memory.getValueAddress(applyFiltersOp.getOutBuf());
|
auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
|
||||||
auto inBufAddr = memory.getValueAddress(applyFiltersOp.getInput());
|
auto inputAddr = memory.getValueAddress(vtanhOp.getInput());
|
||||||
auto accumBufAddr = memory.getValueAddress(applyFiltersOp.getAccumBuf());
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
auto weightIndices = applyFiltersOp.getWeightIndices();
|
json::Object json;
|
||||||
|
json["op"] = "vtanh";
|
||||||
auto inputType = cast<MemRefType>(applyFiltersOp.getInput().getType());
|
json["rd"] = 0;
|
||||||
auto outputType = cast<MemRefType>(applyFiltersOp.getOutBuf().getType());
|
json["rs1"] = 1;
|
||||||
auto inShape = inputType.getShape();
|
json["offset"] = createEmptyOffset();
|
||||||
auto outShape = outputType.getShape();
|
json["len"] = getValueSizeInBytes(vtanhOp.getInput());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
size_t inChannels = inShape[1];
|
|
||||||
size_t outChannels = outShape[1];
|
|
||||||
size_t dimX = inShape.size() > 2 ? inShape[2] : 1;
|
|
||||||
size_t dimY = inShape.size() > 3 ? inShape[3] : 1;
|
|
||||||
|
|
||||||
for (size_t outY = 0; outY < dimY; outY++) {
|
|
||||||
for (size_t outX = 0; outX < dimX; outX++) {
|
|
||||||
|
|
||||||
size_t weightIndex = 0;
|
|
||||||
for (Attribute weight : weightIndices) {
|
|
||||||
// --- STEP 1: Perform MVMUL operation ---
|
|
||||||
auto weightId = cast<IntegerAttr>(weight).getInt();
|
|
||||||
size_t xKer = cast<IntegerAttr>(applyFiltersOp.getXKernelPositions()[weightIndex]).getInt();
|
|
||||||
size_t yKer = cast<IntegerAttr>(applyFiltersOp.getYKernelPositions()[weightIndex]).getInt();
|
|
||||||
weightIndex++;
|
|
||||||
|
|
||||||
if (outX + xKer >= dimX || outY + yKer >= dimY)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
size_t outputOffset = (outY * dimX + outX) * 32 * outChannels;
|
|
||||||
size_t inputOffset = ((outY + yKer) * dimX + (outX + xKer)) * 32 * inChannels;
|
|
||||||
|
|
||||||
bool isFirstWeight = (weightIndices[0] == weight);
|
|
||||||
|
|
||||||
// For the first weight, store directly in output buffer; otherwise use accumulator.
|
|
||||||
size_t rdAddr = isFirstWeight ? outBufAddr : accumBufAddr;
|
|
||||||
size_t rdOffset = isFirstWeight ? outputOffset : 0;
|
|
||||||
emitMvmOp(weightId, rdAddr, rdOffset, inBufAddr, inputOffset);
|
|
||||||
|
|
||||||
// --- STEP 2: Perform VADD operation (skip for first weight) ---
|
|
||||||
if (isFirstWeight)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Sum accumulator with output buffer, store result in output buffer.
|
|
||||||
setupRdRs1Rs2(outBufAddr, outputOffset, accumBufAddr, 0, outBufAddr, outputOffset);
|
|
||||||
|
|
||||||
json::Object vaddJson;
|
|
||||||
vaddJson["op"] = "vvadd";
|
|
||||||
vaddJson["rd"] = 0;
|
|
||||||
vaddJson["rs1"] = 1;
|
|
||||||
vaddJson["rs2"] = 2;
|
|
||||||
vaddJson["offset"] = createEmptyOffset();
|
|
||||||
emitInstruction(std::move(vaddJson));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||||
|
auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
|
||||||
|
auto inputAddr = memory.getValueAddress(vsigmOp.getInput());
|
||||||
|
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||||
|
|
||||||
|
json::Object json;
|
||||||
|
json["op"] = "vsigm";
|
||||||
|
json["rd"] = 0;
|
||||||
|
json["rs1"] = 1;
|
||||||
|
json["offset"] = createEmptyOffset();
|
||||||
|
json["len"] = getValueSizeInBytes(vsigmOp.getInput());
|
||||||
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||||
|
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
|
||||||
|
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
|
||||||
|
|
||||||
|
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||||
|
auto srcShape = srcType.getShape();
|
||||||
|
size_t rank = srcShape.size();
|
||||||
|
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||||
|
size_t totalElements = srcType.getNumElements();
|
||||||
|
|
||||||
|
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||||
|
SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
|
||||||
|
[](auto attr) -> int64_t { return attr.getInt(); });
|
||||||
|
|
||||||
|
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||||
|
SmallVector<int64_t> dstShape(rank);
|
||||||
|
for (size_t i = 0; i < rank; i++)
|
||||||
|
dstShape[i] = srcShape[perm[i]];
|
||||||
|
|
||||||
|
// Row-major strides for source and destination
|
||||||
|
SmallVector<size_t> srcStrides(rank, 1);
|
||||||
|
SmallVector<size_t> dstStrides(rank, 1);
|
||||||
|
for (int64_t i = rank - 2; i >= 0; i--) {
|
||||||
|
srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
|
||||||
|
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit element-by-element copy with transposed addressing
|
||||||
|
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||||
|
// Decompose flat source index into multi-dimensional index
|
||||||
|
SmallVector<size_t> srcIdx(rank);
|
||||||
|
size_t remaining = srcFlat;
|
||||||
|
for (size_t d = 0; d < rank; d++) {
|
||||||
|
srcIdx[d] = remaining / srcStrides[d];
|
||||||
|
remaining %= srcStrides[d];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
|
||||||
|
size_t dstFlat = 0;
|
||||||
|
for (size_t d = 0; d < rank; d++)
|
||||||
|
dstFlat += srcIdx[perm[d]] * dstStrides[d];
|
||||||
|
|
||||||
|
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,7 +443,6 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||||
static OnnxMlirCompilerErrorCodes
|
static OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
|
|
||||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||||
std::error_code errorCode;
|
std::error_code errorCode;
|
||||||
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||||
@@ -355,9 +454,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
|||||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||||
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (getGlobalOp->hasAttr("weightAlways"))
|
if (hasWeightAlways(getGlobalOp))
|
||||||
return;
|
return;
|
||||||
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
if (!globalOp)
|
if (!globalOp)
|
||||||
return;
|
return;
|
||||||
auto initialValue = globalOp.getInitialValue();
|
auto initialValue = globalOp.getInitialValue();
|
||||||
@@ -393,34 +492,43 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
|||||||
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||||
size_t processedOperations = 0;
|
size_t processedOperations = 0;
|
||||||
for (auto& op : coreOp.getBody().front()) {
|
for (auto& op : coreOp.getBody().front()) {
|
||||||
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op))
|
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
||||||
coreCodeGen.codeGenLoadOp(loadOp);
|
coreCodeGen.codeGenLoadOp(loadOp);
|
||||||
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
||||||
coreCodeGen.codeGenStoreOp(storeOp);
|
coreCodeGen.codeGenStoreOp(storeOp);
|
||||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
coreCodeGen.codeGenLmvOp(lmvOp);
|
||||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
|
||||||
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
|
|
||||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
|
||||||
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
|
|
||||||
coreCodeGen.codeGenVAddOp(vaddOp);
|
|
||||||
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
|
|
||||||
coreCodeGen.codeGenVMaxOp(vmaxOp);
|
|
||||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
|
||||||
coreCodeGen.codeGenVReluOp(vreluOp);
|
|
||||||
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
||||||
coreCodeGen.codeGenReceiveOp(receiveOp);
|
coreCodeGen.codeGenReceiveOp(receiveOp);
|
||||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||||
coreCodeGen.codeGenSendOp(sendOp);
|
coreCodeGen.codeGenSendOp(sendOp);
|
||||||
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
|
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||||
// TODO: Implement somehow?
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
||||||
op.emitWarning("Operation is not yet supported in code generation");
|
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||||
continue;
|
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
||||||
}
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||||
|
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||||
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||||
|
coreCodeGen.codeGenVVAddOp(vvaddOp);
|
||||||
|
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
||||||
|
coreCodeGen.codeGenVVSubOp(vvsubOp);
|
||||||
|
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
||||||
|
coreCodeGen.codeGenVVMulOp(vvmulOp);
|
||||||
|
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
||||||
|
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
|
||||||
|
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
||||||
|
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
|
||||||
|
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
||||||
|
coreCodeGen.codeGenVAvgOp(vavgOp);
|
||||||
|
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||||
|
coreCodeGen.codeGenVReluOp(vreluOp);
|
||||||
|
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
||||||
|
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||||
|
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||||
|
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
op.emitError("Unsupported codegen for this operation");
|
||||||
op.dump();
|
op.dump();
|
||||||
@@ -450,7 +558,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr());
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
if (!globalOp) {
|
if (!globalOp) {
|
||||||
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
|
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
|
||||||
weightIndex++;
|
weightIndex++;
|
||||||
@@ -539,7 +647,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
|||||||
|
|
||||||
json::Array outputsAddresses;
|
json::Array outputsAddresses;
|
||||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||||
for (Value output : returnOp.getOperands())
|
for (mlir::Value output : returnOp.getOperands())
|
||||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||||
|
|
||||||
@@ -564,9 +672,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto funcOps = moduleOp.getOps<func::FuncOp>();
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
assert(!funcOps.empty() && "No function found in the module");
|
if (failed(entryFunc))
|
||||||
auto funcOp = *funcOps.begin();
|
return CompilerFailure;
|
||||||
|
auto funcOp = *entryFunc;
|
||||||
|
|
||||||
PimAcceleratorMemory memory;
|
PimAcceleratorMemory memory;
|
||||||
memory.hostMem.allocateHost(moduleOp, funcOp);
|
memory.hostMem.allocateHost(moduleOp, funcOp);
|
||||||
|
|||||||
@@ -3,69 +3,62 @@
|
|||||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
#include "Common/ValueMap.hpp"
|
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
using namespace llvm;
|
|
||||||
using namespace mlir;
|
|
||||||
using Value = mlir::Value;
|
|
||||||
using Type = mlir::Type;
|
|
||||||
using FunctionType = mlir::FunctionType;
|
|
||||||
|
|
||||||
struct MemEntry {
|
struct MemEntry {
|
||||||
size_t address;
|
size_t address;
|
||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimMemory {
|
class PimMemory {
|
||||||
SmallVector<std::pair<MemEntry, Value>, 32> memEntries;
|
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||||
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||||
|
|
||||||
size_t maxSize = 0; // 0 for unbounded memory
|
size_t maxSize = 0; // 0 for unbounded memory
|
||||||
size_t startAddress = 0;
|
size_t startAddress = 0;
|
||||||
size_t minAlignment = 4;
|
size_t minAlignment = 4;
|
||||||
size_t firstAvailableAddress = 0;
|
size_t firstAvailableAddress = 0;
|
||||||
|
|
||||||
MemEntry* gatherMemEntry(Value value);
|
MemEntry* gatherMemEntry(mlir::Value value);
|
||||||
void allocateMemoryForValue(Value value, MemEntry& memEntry);
|
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
|
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
|
||||||
: globalMemEntriesMap(globalMemEntriesMap) {}
|
: globalMemEntriesMap(globalMemEntriesMap) {}
|
||||||
|
|
||||||
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp);
|
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||||
void allocateCore(Operation* op);
|
void allocateCore(mlir::Operation* op);
|
||||||
|
|
||||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||||
MemEntry getMemEntry(Value value) const;
|
MemEntry getMemEntry(mlir::Value value) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimAcceleratorMemory {
|
class PimAcceleratorMemory {
|
||||||
public:
|
public:
|
||||||
SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
|
||||||
PimMemory hostMem;
|
PimMemory hostMem;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SmallDenseMap<size_t, PimMemory> deviceMem;
|
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimAcceleratorMemory()
|
PimAcceleratorMemory()
|
||||||
: hostMem(memEntriesMap) {}
|
: hostMem(memEntriesMap) {}
|
||||||
|
|
||||||
PimMemory getOrCreateDeviceMem(size_t id);
|
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||||
|
|
||||||
size_t getValueAddress(Value value) const;
|
size_t getValueAddress(mlir::Value value) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PimCodeGen {
|
class PimCodeGen {
|
||||||
PimAcceleratorMemory& memory;
|
PimAcceleratorMemory& memory;
|
||||||
raw_fd_ostream& coreFileStream;
|
llvm::raw_fd_ostream& coreFileStream;
|
||||||
|
|
||||||
static json::Object createEmptyOffset();
|
static llvm::json::Object createEmptyOffset();
|
||||||
void emitInstruction(json::Object instruction) const;
|
void emitInstruction(llvm::json::Object instruction) const;
|
||||||
|
|
||||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||||
@@ -73,17 +66,23 @@ class PimCodeGen {
|
|||||||
void setupRdRs1Rs2(
|
void setupRdRs1Rs2(
|
||||||
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
|
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
|
||||||
|
|
||||||
void
|
void emitMemCopyOp(mlir::StringRef opName,
|
||||||
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const;
|
size_t rdAddr,
|
||||||
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
|
size_t rdOffset,
|
||||||
|
size_t rs1Addr,
|
||||||
|
size_t rs1Offset,
|
||||||
|
size_t size,
|
||||||
|
mlir::StringRef sizeFieldName = "size") const;
|
||||||
|
void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
|
||||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson)
|
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
||||||
: memory(memory), coreFileStream(coreJson) {}
|
: memory(memory), coreFileStream(coreJson) {}
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
|
||||||
|
void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const;
|
||||||
|
|
||||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
|
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
|
||||||
void codeGenSendOp(pim::PimSendOp sendOp) const;
|
void codeGenSendOp(pim::PimSendOp sendOp) const;
|
||||||
@@ -91,12 +90,18 @@ public:
|
|||||||
template <typename MVMTy>
|
template <typename MVMTy>
|
||||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
||||||
|
|
||||||
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
|
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
|
||||||
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
|
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
|
||||||
|
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
|
||||||
|
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
|
||||||
|
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
|
||||||
|
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
|
||||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||||
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||||
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName);
|
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
|||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
extern llvm::cl::opt<bool> exportCrossbarWeights;
|
|
||||||
|
|
||||||
extern llvm::cl::opt<size_t> crossbarSize;
|
extern llvm::cl::opt<size_t> crossbarSize;
|
||||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerPasses.hpp"
|
#include "src/Compiler/CompilerPasses.hpp"
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerUtils"
|
#define DEBUG_TYPE "PimCompilerUtils"
|
||||||
|
|
||||||
@@ -34,18 +33,23 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPim) {
|
if (pimEmissionTarget >= EmitPim) {
|
||||||
pm.addPass(createSpatialToPIMPass());
|
pm.addPass(createSpatialToPimPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||||
pm.addPass(createBufferizePimPass());
|
pm.addPass(createPimBufferizationPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Pim bufferized"));
|
pm.addPass(createMessagePass("Pim bufferized"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
|
pm.addPass(createPimConstantFoldingPass());
|
||||||
|
pm.addPass(createMessagePass("Pim constants folded"));
|
||||||
|
pm.addPass(createPimMaterializeConstantsPass());
|
||||||
|
pm.addPass(createPimVerificationPass());
|
||||||
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
add_subdirectory(ONNXToSpatial)
|
add_subdirectory(ONNXToSpatial)
|
||||||
add_subdirectory(SpatialToGraphviz)
|
add_subdirectory(SpatialToGraphviz)
|
||||||
add_subdirectory(SpatialToPIM)
|
add_subdirectory(SpatialToPim)
|
||||||
@@ -2,33 +2,30 @@ set(LLVM_TARGET_DEFINITIONS ONNXToSpatial.td)
|
|||||||
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||||
|
|
||||||
add_onnx_mlir_library(OMONNXToSpatial
|
add_pim_library(OMONNXToSpatial
|
||||||
Math/Gemm.cpp
|
Patterns/Math/Conv.cpp
|
||||||
Math/Conv.cpp
|
Patterns/Math/Gemm.cpp
|
||||||
Math/ExperimentalConv.cpp
|
Patterns/Math/MatMul.cpp
|
||||||
Math/ExperimentalGemm.cpp
|
Patterns/NN/Pool.cpp
|
||||||
NN/Pooling.cpp
|
Patterns/NN/Relu.cpp
|
||||||
NN/ExperimentalPooling.cpp
|
Patterns/Tensor/Concat.cpp
|
||||||
NN/ReduceMean.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
Tensor/ONNXConcatToTensorConcat.cpp
|
|
||||||
Tensor/RemoveUnusedHelperOps.cpp
|
|
||||||
Utils/SpatialReducer.cpp
|
|
||||||
Utils/WeightSubdivider.cpp
|
|
||||||
Utils/AnnotateReplication.cpp
|
|
||||||
ONNXToSpatialPass.hpp
|
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
ONNXToSpatialCommon.cpp
|
Common.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
ONNXToSpatialIncGen
|
ONNXToSpatialIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRTosaDialect
|
||||||
OMCompilerOptions
|
OMCompilerOptions
|
||||||
OMPimCompilerOptions
|
OMPimCompilerOptions
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
${PIM_INCLUDE_PATH}
|
${PIM_GENERATED_INCLUDE_DIRS}
|
||||||
)
|
)
|
||||||
|
|||||||
137
src/PIM/Conversion/ONNXToSpatial/Common.cpp
Normal file
137
src/PIM/Conversion/ONNXToSpatial/Common.cpp
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Location.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
SmallVector<Value> sliceTensor(
|
||||||
|
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||||
|
assert("Invalid axis" && axis < shape.size());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||||
|
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.reserve(shape.size());
|
||||||
|
for (const auto size : shape)
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(size));
|
||||||
|
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||||
|
|
||||||
|
long length = shape[axis];
|
||||||
|
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
|
||||||
|
SmallVector<Value> slices;
|
||||||
|
slices.reserve(numSlices);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < numSlices; i++) {
|
||||||
|
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||||
|
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||||
|
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||||
|
|
||||||
|
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||||
|
slices.push_back(slice);
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value>
|
||||||
|
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||||
|
assert("Not a vector" && isVectorShape(shape));
|
||||||
|
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||||
|
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseMap<CoreId, SmallVector<Value>>
|
||||||
|
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||||
|
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||||
|
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||||
|
size_t coreId = sliceId / crossbarCountInCore;
|
||||||
|
slicesPerCore[coreId].push_back(slices[sliceId]);
|
||||||
|
}
|
||||||
|
return slicesPerCore;
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||||
|
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
|
||||||
|
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
|
||||||
|
|
||||||
|
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
|
||||||
|
|
||||||
|
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
|
||||||
|
size_t numHSlices = hSlices.size();
|
||||||
|
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
|
||||||
|
Value hSlice = hSlices[hSliceId];
|
||||||
|
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
|
||||||
|
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
|
||||||
|
size_t coreId = vSliceId / crossbarCountInCore;
|
||||||
|
Value vSlice = vSlices[vSliceId];
|
||||||
|
tiles[hSliceId][coreId].push_back(vSlice);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tiles;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor::SplatOp
|
||||||
|
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||||
|
Type elementType = oldType.getElementType();
|
||||||
|
int64_t shape[2] = {1, length};
|
||||||
|
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||||
|
|
||||||
|
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
|
SmallVector<Value> index(oldType.getRank(), zero);
|
||||||
|
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||||
|
|
||||||
|
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||||
|
if (tensors.size() == 1)
|
||||||
|
return tensors[0];
|
||||||
|
|
||||||
|
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||||
|
SmallVector<Value> tensors2;
|
||||||
|
tensors2.reserve(tensors.size() / 2);
|
||||||
|
|
||||||
|
auto* currTensors = &tensors1;
|
||||||
|
auto* nextTensors = &tensors2;
|
||||||
|
while (currTensors->size() > 1) {
|
||||||
|
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||||
|
Value a = (*currTensors)[i];
|
||||||
|
Value b = (*currTensors)[i + 1];
|
||||||
|
rewriter.setInsertionPointAfterValue(b);
|
||||||
|
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||||
|
nextTensors->push_back(addedValue);
|
||||||
|
}
|
||||||
|
if (currTensors->size() % 2 == 1)
|
||||||
|
nextTensors->push_back(currTensors->back());
|
||||||
|
std::swap(currTensors, nextTensors);
|
||||||
|
nextTensors->clear();
|
||||||
|
}
|
||||||
|
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||||
|
return (*currTensors)[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace onnx_mlir
|
||||||
167
src/PIM/Conversion/ONNXToSpatial/Common.hpp
Normal file
167
src/PIM/Conversion/ONNXToSpatial/Common.hpp
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
#define DEFINE_MAP_OP(opname) opname,
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getImageN(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ShapedType>
|
||||||
|
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||||
|
return shapedType.getDimSize(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
using HSliceId = size_t;
|
||||||
|
using CoreId = size_t;
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr C ceilIntegerDivide(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return 1 + (ac - 1) / bc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||||
|
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||||
|
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||||
|
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||||
|
C ac = static_cast<C>(a);
|
||||||
|
C bc = static_cast<C>(b);
|
||||||
|
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[0] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||||
|
return shape.size() == 2 && shape[1] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||||
|
assert(isVectorShape(shape));
|
||||||
|
return shape[0] != 1 ? shape[0] : shape[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
inline auto getTensorShape(mlir::Value tensor) {
|
||||||
|
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
|
std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
template <size_t NumInputs, typename BodyFn>
|
||||||
|
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
|
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
|
size_t axis,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||||
|
int64_t sliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
|
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
|
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||||
|
tileMatrix(mlir::Value& matrixToTile,
|
||||||
|
int64_t hSliceSize,
|
||||||
|
int64_t vSliceSize,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location& loc);
|
||||||
|
|
||||||
|
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||||
|
int64_t length,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||||
|
|
||||||
|
}; // namespace onnx_mlir
|
||||||
@@ -1,583 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/IRMapping.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include <cstddef>
|
|
||||||
#include <memory>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
// NOTE:
|
|
||||||
// This might be useful to re-implement this considering for loops.
|
|
||||||
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A momentary representation of a core, to be used within the tiling of
|
|
||||||
* a convolution operation.
|
|
||||||
*/
|
|
||||||
class Core {
|
|
||||||
public:
|
|
||||||
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
|
|
||||||
: coreId(coreId), rewriter(rewriter) {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Add a MVM operation to the core.
|
|
||||||
*
|
|
||||||
* @param inputTile The input tile to the MVM operation.
|
|
||||||
* @param xbarIndex The index of the crossbar weight to use.
|
|
||||||
* @param outputTileId The id of the output tile.
|
|
||||||
* @param mvmOutType The result's shape.
|
|
||||||
* @return Value The result of the MVM operation.
|
|
||||||
*/
|
|
||||||
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
|
|
||||||
// Use the inputTile as the reference location for the MVM operation.
|
|
||||||
Location loc = inputTile.getLoc();
|
|
||||||
|
|
||||||
// Move the insertion point to the end of the block.
|
|
||||||
rewriter.setInsertionPointToEnd(block.get());
|
|
||||||
|
|
||||||
// Add the inputTile to the block arguments, and to the operands.
|
|
||||||
Value operand = operandMap.lookupOrNull(inputTile);
|
|
||||||
if (not operand) {
|
|
||||||
operand = block->addArgument(inputTile.getType(), loc);
|
|
||||||
operands.push_back(inputTile);
|
|
||||||
operandMap.map(inputTile, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
|
|
||||||
// is correct.
|
|
||||||
|
|
||||||
// Construct the MVM operation
|
|
||||||
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
|
|
||||||
|
|
||||||
// Since we are within the same core and no computation can happen in
|
|
||||||
// paralllel, we can just apply a linear reduction in case we have multiple
|
|
||||||
// MVM operations for the same outputTile.
|
|
||||||
auto lastMVM = outputTileToMVM.find(outputTileId);
|
|
||||||
|
|
||||||
// If an entry for this outputTile already exists, apply reduction.
|
|
||||||
if (lastMVM != outputTileToMVM.end()) {
|
|
||||||
// MVM results should have the same type for reduction.
|
|
||||||
assert(lastMVM->second.getType() == result.getType());
|
|
||||||
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
|
|
||||||
}
|
|
||||||
|
|
||||||
outputTileToMVM[outputTileId] = result;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Mark a result as remappable, and return a shared pointer to it.
|
|
||||||
*
|
|
||||||
* This function marks a result as remappable, and returns a shared pointer to
|
|
||||||
* it. We need to keep track of these values to generate the YieldOp at a
|
|
||||||
* later stage.
|
|
||||||
*
|
|
||||||
* @param result A result to track, for later remapping.
|
|
||||||
* @return shared_ptr<Value> A shared pointer to the result.
|
|
||||||
*/
|
|
||||||
shared_ptr<Value> makeResultRemappable(Value result) {
|
|
||||||
// Verify that the result is present in the block.
|
|
||||||
assert(result.getDefiningOp()->getBlock() == block.get());
|
|
||||||
|
|
||||||
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
|
|
||||||
|
|
||||||
resultsToRemap.push_back(remappableResult);
|
|
||||||
results.push_back(result);
|
|
||||||
|
|
||||||
return remappableResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Add a remappable operand to the core, to merge partial results
|
|
||||||
* inter-core.
|
|
||||||
*
|
|
||||||
* @param remappableOperand The operand to add.
|
|
||||||
* @return Value The block argument representing the operand.
|
|
||||||
*/
|
|
||||||
Value addRemappableOperand(std::shared_ptr<Value> operand) {
|
|
||||||
// Check that the operand is not already there.
|
|
||||||
assert(not operandMap.contains(*operand));
|
|
||||||
|
|
||||||
Value argument = block->addArgument(operand->getType(), operand->getLoc());
|
|
||||||
remappableOperands.push_back(operand);
|
|
||||||
return argument;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
|
|
||||||
*
|
|
||||||
* @param loc The location of the operation.
|
|
||||||
* @return spatial::SpatWeightedCompute
|
|
||||||
*/
|
|
||||||
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
|
|
||||||
// Get the shape of the results.
|
|
||||||
SmallVector<Type> resultTypes;
|
|
||||||
for (const auto& value : results)
|
|
||||||
resultTypes.push_back(value.getType());
|
|
||||||
|
|
||||||
// Create the WComputeOp, with non-remappable operands only.
|
|
||||||
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
|
|
||||||
|
|
||||||
// Add the body to the WComputeOp.
|
|
||||||
Block* releasedBlock = block.release();
|
|
||||||
wcomputeOp.getBody().push_back(releasedBlock);
|
|
||||||
|
|
||||||
// Add the `yieldOp` at the end, with the results.
|
|
||||||
rewriter.setInsertionPointToEnd(releasedBlock);
|
|
||||||
rewriter.create<spatial::SpatYieldOp>(loc, results);
|
|
||||||
|
|
||||||
return wcomputeOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Remap the results to the WComputeOp results.
|
|
||||||
*/
|
|
||||||
void remapResults() {
|
|
||||||
// Remap all the results to the WComputeOp results.
|
|
||||||
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
|
|
||||||
for (size_t i = 0; i < resultsToRemap.size(); i++)
|
|
||||||
*resultsToRemap[i] = wcomputeOp.getResult(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
void addRemappedOperands() {
|
|
||||||
// Insert the remappableOperands (which were remapped in
|
|
||||||
// `addRemappableOperand` of another Core)
|
|
||||||
for (auto remappedValue : remappableOperands)
|
|
||||||
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
|
|
||||||
|
|
||||||
// Update the wcomputeOp operandSegmentSize
|
|
||||||
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t addXbarWeight(Value weight) {
|
|
||||||
assert(!isXbarsFull());
|
|
||||||
xbarWeights.push_back(weight);
|
|
||||||
return xbarWeights.size() - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isXbarsFull() {
|
|
||||||
assert(xbarWeights.size() <= crossbarCountInCore);
|
|
||||||
return xbarWeights.size() == crossbarCountInCore;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isCoreEmpty() { return block->empty(); }
|
|
||||||
|
|
||||||
void dump() {
|
|
||||||
// Print the coreId
|
|
||||||
llvm::outs() << "Core " << coreId << ":\n";
|
|
||||||
// Print the weights
|
|
||||||
llvm::outs() << "Xbar Weights:\n";
|
|
||||||
for (auto weight : xbarWeights)
|
|
||||||
weight.dump();
|
|
||||||
// Print the operands
|
|
||||||
llvm::outs() << "Operands:\n";
|
|
||||||
for (auto operand : operands)
|
|
||||||
llvm::outs() << operand << "\n";
|
|
||||||
|
|
||||||
// Dump the body block
|
|
||||||
for (auto& op : block->getOperations())
|
|
||||||
op.dump();
|
|
||||||
|
|
||||||
// Print the results
|
|
||||||
llvm::outs() << "Results:\n";
|
|
||||||
for (auto result : results)
|
|
||||||
llvm::outs() << result << "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t coreId;
|
|
||||||
|
|
||||||
private:
|
|
||||||
ConversionPatternRewriter& rewriter;
|
|
||||||
|
|
||||||
// Should these be set<Value> instead? But I need to keep the order
|
|
||||||
vector<Value> operands;
|
|
||||||
vector<std::shared_ptr<Value>> remappableOperands;
|
|
||||||
|
|
||||||
vector<Value> results;
|
|
||||||
vector<std::shared_ptr<Value>> resultsToRemap;
|
|
||||||
|
|
||||||
// Maps from input tiles to the block operand
|
|
||||||
IRMapping operandMap;
|
|
||||||
|
|
||||||
// Map from outputTileId to MVM operation producing it
|
|
||||||
unordered_map<size_t, Value> outputTileToMVM;
|
|
||||||
|
|
||||||
vector<Value> xbarWeights;
|
|
||||||
|
|
||||||
unique_ptr<mlir::Block> block = make_unique<Block>();
|
|
||||||
|
|
||||||
spatial::SpatWeightedCompute wcomputeOp;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
|
||||||
ONNXConvOpTile(MLIRContext* ctx)
|
|
||||||
: OpConversionPattern(ctx) {}
|
|
||||||
|
|
||||||
struct Producer_t {
|
|
||||||
Value value;
|
|
||||||
shared_ptr<Core> core;
|
|
||||||
};
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
|
|
||||||
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
|
|
||||||
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
|
|
||||||
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
|
|
||||||
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
|
|
||||||
|
|
||||||
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
|
|
||||||
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
|
|
||||||
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
|
|
||||||
|
|
||||||
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
|
|
||||||
if (padUnpackError.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
|
|
||||||
|
|
||||||
// TODO: Pad value at beginning and end of each dimension could be
|
|
||||||
// different. We should handle this case.
|
|
||||||
|
|
||||||
// MapOperations mapOperation = MapOperations::None;
|
|
||||||
//
|
|
||||||
// // If we have just one user, and it is an activation funcion (or more in
|
|
||||||
// // general a mapping operation) just inline it in the computeOps
|
|
||||||
// auto firstUserOp = *conv->getUsers().begin();
|
|
||||||
// if (conv->hasOneUse()) {
|
|
||||||
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
|
|
||||||
//
|
|
||||||
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
|
|
||||||
// return rewriter.notifyMatchFailure(
|
|
||||||
// conv, "Softmax not supported as activation for convolutions.");
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
|
||||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
|
||||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
|
||||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
|
||||||
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
|
|
||||||
size_t krn_w = GET_KERNEL_WIDTH(wShape);
|
|
||||||
|
|
||||||
Location loc = conv.getLoc();
|
|
||||||
|
|
||||||
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
|
||||||
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
|
||||||
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
|
|
||||||
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
|
|
||||||
|
|
||||||
// Tile the input tensor
|
|
||||||
// Input tiles need to be indexed by:
|
|
||||||
// a. Channel Tile
|
|
||||||
// b. Pixel `x` position
|
|
||||||
// c. Pixel `y` position
|
|
||||||
// For example: inputTiles[channelTile][x][y]
|
|
||||||
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
|
||||||
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
|
||||||
|
|
||||||
auto resolveErrorOpt = resolveImgInputTiles(
|
|
||||||
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
|
|
||||||
if (resolveErrorOpt.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(crossbarSize),
|
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
|
|
||||||
// Tile the weight tensor
|
|
||||||
// Weight tiles need to be indexed by:
|
|
||||||
// a. Filter Tile
|
|
||||||
// b. Channel Tile
|
|
||||||
// c. Kernel `x` position
|
|
||||||
// d. Kernel `y` position
|
|
||||||
// For example: weightTiles[filterTile][channelTile][x][y]
|
|
||||||
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
|
|
||||||
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
|
|
||||||
outputTileCount,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
|
|
||||||
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
|
|
||||||
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
|
||||||
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
|
||||||
sizes = {rewriter.getIndexAttr(crossbarSize),
|
|
||||||
rewriter.getIndexAttr(crossbarSize),
|
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
for (size_t i = 0; i < outputTileCount; i++) {
|
|
||||||
if (i == outputTileCount - 1 && outputTileRemainder != 0)
|
|
||||||
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
|
|
||||||
sizes[1] = rewriter.getIndexAttr(crossbarSize);
|
|
||||||
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
|
|
||||||
for (size_t j = 0; j < inputTileCount; j++) {
|
|
||||||
if (j == inputTileCount - 1 && inputTileRemainder != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
|
|
||||||
for (size_t x = 0; x < krn_w; x++) {
|
|
||||||
for (size_t y = 0; y < krn_h; y++) {
|
|
||||||
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
weightTiles[i][j][x][y] =
|
|
||||||
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Distribute the computation among many compute cores
|
|
||||||
* Try to compute in-core the computation for each output tile, and reduce
|
|
||||||
* over as few cores as possible
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Tile the output tensor
|
|
||||||
// Output tiles need to be indexed by:
|
|
||||||
// a. Filter Tile
|
|
||||||
// b. Pixel `x` position
|
|
||||||
// c. Pixel `y` position
|
|
||||||
// For example: outputTiles[filterTile][x][y]
|
|
||||||
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
|
|
||||||
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
|
|
||||||
outputTileCount,
|
|
||||||
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
|
|
||||||
|
|
||||||
size_t replicationFactor;
|
|
||||||
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
|
|
||||||
replicationFactor = 1;
|
|
||||||
else
|
|
||||||
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
|
|
||||||
// producers[outTile][out_x][out_y][producerIndex]
|
|
||||||
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
|
|
||||||
outputTileCount,
|
|
||||||
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
|
|
||||||
|
|
||||||
// Schedule in cores
|
|
||||||
size_t coreId = 0;
|
|
||||||
vector<shared_ptr<Core>> curCores(replicationFactor);
|
|
||||||
for (size_t i = 0; i < replicationFactor; i++)
|
|
||||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
|
||||||
|
|
||||||
vector<shared_ptr<Core>> cores;
|
|
||||||
|
|
||||||
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
|
|
||||||
|
|
||||||
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
|
|
||||||
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
|
|
||||||
|
|
||||||
RankedTensorType mvmOutType =
|
|
||||||
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
|
|
||||||
|
|
||||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
|
||||||
|
|
||||||
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
|
|
||||||
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
|
|
||||||
|
|
||||||
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
|
|
||||||
|
|
||||||
vector<size_t> xbarIndexes(replicationFactor);
|
|
||||||
for (size_t i = 0; i < replicationFactor; i++)
|
|
||||||
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
|
|
||||||
|
|
||||||
size_t out_x = 0;
|
|
||||||
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
|
|
||||||
size_t out_y = 0;
|
|
||||||
|
|
||||||
// I use `replicationFactor` cores. I divide the input_w into
|
|
||||||
// `replicationFactor` slices, and each slice is distributed to a
|
|
||||||
// core. `coreIndex` is the index of the core that will be used
|
|
||||||
// for this slice
|
|
||||||
size_t coreIndex = in_x / replicationSliceSize;
|
|
||||||
assert(coreIndex < replicationFactor);
|
|
||||||
|
|
||||||
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
|
|
||||||
// Adjust the input based on the kernel
|
|
||||||
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
|
|
||||||
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
|
|
||||||
|
|
||||||
// Check if we are within the input image
|
|
||||||
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
|
|
||||||
out_y++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
|
|
||||||
auto mvm = curCores[coreIndex]->addMVM(
|
|
||||||
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
|
|
||||||
|
|
||||||
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
|
|
||||||
|
|
||||||
out_y++;
|
|
||||||
}
|
|
||||||
out_x++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computations for these crossbars are done, check if the cores
|
|
||||||
// crossbars are fully used. If full, swap with new core
|
|
||||||
for (size_t i = 0; i < replicationFactor; i++) {
|
|
||||||
if (curCores[i]->isXbarsFull()) {
|
|
||||||
cores.emplace_back(std::move(curCores[i]));
|
|
||||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& curCore : curCores)
|
|
||||||
if (curCore->isCoreEmpty() == false)
|
|
||||||
cores.emplace_back(std::move(curCore));
|
|
||||||
curCores.clear();
|
|
||||||
// Now, do the reduction of each output pixel tile
|
|
||||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
|
||||||
for (size_t out_x = 0; out_x < output_w; out_x++) {
|
|
||||||
for (size_t out_y = 0; out_y < output_h; out_y++) {
|
|
||||||
// First, check if some producers are within the same core. If this is
|
|
||||||
// true, `Core::addMVM` has already done the reduction within-core.
|
|
||||||
// This means that we only need to consider the last producer for that
|
|
||||||
// core.
|
|
||||||
|
|
||||||
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
|
|
||||||
for (auto producer : producers[outTile][out_x][out_y])
|
|
||||||
withinCoreReducedProducers[producer.core->coreId] = producer;
|
|
||||||
|
|
||||||
// Now, we need to apply inter-core reduction
|
|
||||||
|
|
||||||
// Base case with one producer
|
|
||||||
if (withinCoreReducedProducers.size() == 1) {
|
|
||||||
// TODO: Add the bias and apply mapping (if present)
|
|
||||||
|
|
||||||
auto singleProducer = withinCoreReducedProducers.begin()->second;
|
|
||||||
// Use last producer as the final result
|
|
||||||
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
|
|
||||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: This is a linear reduction, not a tree reduction. We can do
|
|
||||||
// better: a tree reduction would make more computations happen in
|
|
||||||
// parallel.
|
|
||||||
|
|
||||||
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
|
|
||||||
|
|
||||||
auto it = withinCoreReducedProducers.begin();
|
|
||||||
it++;
|
|
||||||
while (it != withinCoreReducedProducers.end()) {
|
|
||||||
|
|
||||||
Producer_t curProducer = it->second;
|
|
||||||
|
|
||||||
shared_ptr<Core> core1;
|
|
||||||
shared_ptr<Core> core2;
|
|
||||||
Value core1Value;
|
|
||||||
Value core2Value;
|
|
||||||
|
|
||||||
auto lastProducerCoreId = lastProducer.core->coreId;
|
|
||||||
auto curProducerCoreId = curProducer.core->coreId;
|
|
||||||
|
|
||||||
assert(lastProducerCoreId != curProducerCoreId
|
|
||||||
&& "We should have already applied within-core reduction, how "
|
|
||||||
"could we have same cores here?");
|
|
||||||
|
|
||||||
// Sort the cores by coreId
|
|
||||||
if (curProducerCoreId < lastProducerCoreId) {
|
|
||||||
core1 = curProducer.core;
|
|
||||||
core1Value = curProducer.value;
|
|
||||||
core2 = lastProducer.core;
|
|
||||||
core2Value = lastProducer.value;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
core1 = lastProducer.core;
|
|
||||||
core1Value = lastProducer.value;
|
|
||||||
core2 = curProducer.core;
|
|
||||||
core2Value = curProducer.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newCoreRes = core1->makeResultRemappable(core1Value);
|
|
||||||
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfterValue(core2Value);
|
|
||||||
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
|
|
||||||
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
|
|
||||||
|
|
||||||
lastProducer = {vaddRes, core2};
|
|
||||||
|
|
||||||
it++;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Add the bias and apply mapping (if present)
|
|
||||||
|
|
||||||
// Use last producer as the final result
|
|
||||||
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
|
|
||||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
|
|
||||||
rewriter.setInsertionPointAfter(conv);
|
|
||||||
spatial::SpatWeightedCompute lastWComputeOp;
|
|
||||||
for (auto& core : cores) {
|
|
||||||
lastWComputeOp = core->createWComputeOp(loc);
|
|
||||||
core->remapResults();
|
|
||||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& core : cores)
|
|
||||||
core->addRemappedOperands();
|
|
||||||
|
|
||||||
// Set the insertion point after the last WComputeOp.
|
|
||||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
|
||||||
SmallVector<Value> tilesToConcat;
|
|
||||||
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
|
|
||||||
for (size_t outX = 0; outX < output_h; outX++)
|
|
||||||
for (size_t outY = 0; outY < output_w; outY++)
|
|
||||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
|
||||||
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
|
|
||||||
|
|
||||||
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
|
|
||||||
|
|
||||||
// Value outputImage =
|
|
||||||
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
|
|
||||||
|
|
||||||
// If no mapping (activation) was applied, just replace ConvOp
|
|
||||||
// if (mapOperation == MapOperations::None) {
|
|
||||||
// rewriter.replaceOp(conv, outputImage);
|
|
||||||
// } else {
|
|
||||||
// // If mapping was applied, erase ConvOp and replace the mapping op
|
|
||||||
// rewriter.eraseOp(conv);
|
|
||||||
// rewriter.replaceOp(firstUserOp, outputImage);
|
|
||||||
// }
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
 * @brief Register the convolution tiling pattern into the given set.
 *
 * @param patterns Pattern set that collects the rewrite patterns.
 * @param ctx MLIR context the pattern is constructed against.
 */
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  // `add` is the modern RewritePatternSet spelling of `insert`.
  patterns.add<ONNXConvOpTile>(ctx);
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,400 +0,0 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstddef>
|
|
||||||
#include <unistd.h>
|
|
||||||
|
|
||||||
#include "Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A pattern to tile the convolution operation into a series of compute
|
|
||||||
* units, each one of which applies filters to a subset of the input
|
|
||||||
* tensor. Results are also reduced and concatenated to form the final
|
|
||||||
* output tensor.
|
|
||||||
*/
|
|
||||||
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
  /// Construct the pattern against the given MLIR context.
  ExperimentalONNXConvOpTile(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /// @brief Tile an ONNX convolution into spatial weighted-compute units.
  ///
  /// The weights tensor is sliced per (input tile, output tile, kernel x,
  /// kernel y) position; the slices are packed into compute units of at most
  /// `crossbarCountInCore` crossbars, each unit applies its filters to the
  /// matching input slice, partial results are reduced (within the unit and
  /// across units), and the per-output-tile results are finally concatenated.
  ///
  /// @param conv        The ONNX convolution being rewritten.
  /// @param convAdaptor Adaptor exposing the already-converted operands.
  /// @param rewriter    Conversion rewriter used to build the new IR.
  /// @return success() once the rewrite has been applied.
  LogicalResult
  matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {

    // --------------------------------- //
    // --- READ OPERATION PARAMETERS --- //
    // --------------------------------- //

    // To get each crossbar's weights, we need to slice the weights tensor.
    // - Along the input tiles.
    // - Along the output tiles.
    // - Along the filter x position.
    // - Along the filter y position.
    ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
    ShapedType outputType = cast<ShapedType>(conv.getY().getType());
    ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());

    // TODO: Address bigger batches.
    // NOTE(review): the adjacent string literals concatenate to
    // "Batch size must be 1for convolution." — missing space in the message.
    assert(GET_IMAGE_N(inputType) == 1
           && "Batch size must be 1"
              "for convolution.");

    // TODO: Address replication.
    assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");

    // TODO: Address bias addition.

    // Number of crossbar-sized tiles along input/output channels; `.quot` is
    // the count of full tiles, `.rem` the size of a trailing partial tile.
    ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
    ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
    size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
    size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);

    // Assert that the kernel is square.
    assert(kernelWidth == kernelHeight && "Only square kernels are supported.");

    // -------------------------------- //
    // --- SLICE THE WEIGHTS TENSOR --- //
    // -------------------------------- //

    // The core idea of this stage is classifying the weights by input and
    // output tile. This is because we want the applyFilters operations to be
    // tile agnostic, to keep the subsequent lowering stages as simple as
    // possible. This data structure does this weight classification:
    // - The outer map is indexed by input tile.
    // - The inner map is indexed by output tile.
    // - The SmallVector contains the weights for the filter.
    map<long, map<long, SmallVector<Value>>> weightsGroups;

    // During all slicing operations within this stage, we'll use the same
    // strides for all dimensions.
    SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));

    ldiv_t itc = inputTileCount;
    ldiv_t otc = outputTileCount;

    // - Slicing along the input tiles.
    // - Slicing along the output tiles.
    // The `+ (rem > 0)` adds one extra iteration for a trailing partial tile.
    for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
      long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
      for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
        long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;

        // The loop above also sets the crossbar's used width and height,
        // checking if we're at the last crossbar and if it's incomplete.

        long outputTile = ot;
        long inputTile = it;

        // Create the slicing sizes.
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
            /* 1 */ rewriter.getIndexAttr(crossbarWidth),
            /* 2 */ rewriter.getIndexAttr(1),
            /* 3 */ rewriter.getIndexAttr(1)};

        // - Slicing along the filter x position.
        // - Slicing along the filter y position.
        for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
          for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {

            // Create the slicing offsets.
            SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
                /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
                /* 2 */ rewriter.getIndexAttr(filterX),
                /* 3 */ rewriter.getIndexAttr(filterY)};

            // Create the slice extraction operation.
            auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
                conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);

            // Add a note to the extractSliceOp, with the filterX and filterY.
            weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
          }
        }
      }
    }

    // TODO: Tree reduction for compute reduction should be implemented.

    // -------------------------------- //
    // --- CREATE ALL COMPUTE UNITS --- //
    // -------------------------------- //

    // Keep track of input slicing operations to avoid duplication across
    // all compute units (global slices).
    map<long, Value> globalSlices;

    // Keep track of all partial compute results.
    map<long, Value> globalPartialResults;

    // Use a weight subdivider to extract groups of weights for each compute
    // unit. We'll keep extracting groups until no more weights are left.
    WeightSubdivider weightSubdivider(weightsGroups);
    while (!weightSubdivider.isEmpty()) {

      // -------------------------------- //
      // --- BEGIN A NEW COMPUTE UNIT --- //
      // -------------------------------- //

      // Get the next group of weights for the compute unit.
      // NOTE: this shadows the outer `weightsGroups` map on purpose.
      SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());

      SmallVector<Value> computeWeights;
      SmallVector<Value> computeOperands;

      // ------------------------------ //
      // --- SLICE THE INPUT TENSOR --- //
      // ------------------------------ //

      // Note each tile's index in the compute unit arguments.
      map<long, size_t> inputTileIndices;
      map<long, size_t> outputTileIndices;
      map<long, size_t> reductionTileIndices; // Incoming partial results.

      // Iterate over all weights groups for this compute unit.
      map<long, Value> localSlices; // WRT the current compute unit.
      for (auto group : weightsGroups) {
        for (Value weight : group.weights)
          computeWeights.push_back(weight);

        // There might be multiple weight groups for the same input tile, so if
        // we've already added the input tile, skip it.
        if (localSlices.find(group.inputTile) != localSlices.end())
          continue;

        // We might have already sliced the input tensor for some other compute
        // unit, so if we have, reuse the slicing operation without creating a
        // new one.
        // NOTE(review): this branch never records
        // `inputTileIndices[group.inputTile]`, so a tile whose slice is reused
        // from a previous unit later reads a default-inserted index 0 in
        // `block->getArgument(inputTileIndices[group.inputTile])` — verify
        // this is not picking the wrong block argument.
        if (globalSlices.find(group.inputTile) != globalSlices.end()) {
          computeOperands.push_back(globalSlices[group.inputTile]);
          localSlices[group.inputTile] = globalSlices[group.inputTile];
          continue;
        }

        // Create the input tensor slicing offsets.
        SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
            /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
            /* 2 */ rewriter.getIndexAttr(0),
            /* 3 */ rewriter.getIndexAttr(0)};

        // Create the input tensor slicing sizes.
        size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
        SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
            /* 1 */ rewriter.getIndexAttr(tilingSize),
            /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
            /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};

        // Create the slice extraction operation.
        auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
            conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);

        computeOperands.push_back(extractSliceOp);

        // Update slicing maps.
        globalSlices[group.inputTile] = extractSliceOp;
        localSlices[group.inputTile] = extractSliceOp;

        // Update the input tile index.
        inputTileIndices[group.inputTile] = computeOperands.size() - 1;
      }

      // ------------------------------- //
      // --- PREPARE THE OUTPUT TYPE --- //
      // ------------------------------- //

      // Fill the compute output's type by looking at the output tiles.
      SmallVector<Type> computeOutputType;
      for (TaggedWeights group : weightsGroups) {

        // There might be multiple weight groups for the same output tile, so if
        // we've already added the output tile, skip it.
        if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
          continue;

        // Additionally, after adding the input slices as operands, also add any
        // compatible partial results from previous compute units.
        if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
          computeOperands.push_back(globalPartialResults[group.outputTile]);
          reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
        }

        // Define the output shape for this group.
        long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;

        // TODO: Address non-same padding.
        SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
            /* 1 */ outputTileSize,
            /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
            /* 3 */ GET_IMAGE_HEIGHT(outputType)};

        auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();

        computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));

        outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
      }

      // ----------------------------- //
      // --- FILL THE COMPUTE UNIT --- //
      // ----------------------------- //

      // Create the compute unit.
      spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
          conv.getLoc(), computeOutputType, computeWeights, computeOperands);

      // Create a new block for the compute unit and add the operands.
      Block* block = rewriter.createBlock(&currentCompute.getRegion());
      rewriter.setInsertionPointToStart(block);
      for (Value operand : computeOperands)
        block->addArgument(operand.getType(), conv->getLoc());

      // Initialize a map of local partial results.
      map<long, Value> localPartialResults; // WRT the current compute unit.

      // If we have any reduction tiles, add them to the local partial results.
      for (auto reductionTileIndex : reductionTileIndices)
        localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);

      // Add all the applyFilters operations to the block.
      for (TaggedWeights group : weightsGroups) {

        // Get the outputType for this group.
        Type outputType = computeOutputType[outputTileIndices[group.outputTile]];

        // Create an apply filters operation.
        BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);

        // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
        // ... As many weights as the size of group.weights.
        SmallVector<long> weightIndices;
        for (size_t i = 0; i < group.weights.size(); ++i)
          weightIndices.push_back(group.startingCrossbarIndex + i);

        SmallVector<int64_t> xKerPos;
        SmallVector<int64_t> yKerPos;
        for (auto weight : group.weights) {
          // Assert that the weight is an extract_slice operation.
          auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
          assert(extractSliceOp && "Weight is not an extract_slice operation.");

          // Get the filter x and y positions from the extract_slice operation.
          // Offsets 2 and 3 are the kernel x/y set during weight slicing.
          auto offsets = extractSliceOp.getStaticOffsets();
          xKerPos.push_back(offsets[2]);
          yKerPos.push_back(offsets[3]);
        }

        ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
        ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
        ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);

        Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
            conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);

        // Perform local reduction if necessary.
        if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {

          result = rewriter.create<spatial::SpatVAddOp>(
              conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
        }

        // Update the partial results map.
        localPartialResults[group.outputTile] = result;
      }

      // Add a yield operation to the block by concatenating the partial
      // results.
      SmallVector<Value> applyFiltersResults;
      for (size_t i = 0; i < computeOutputType.size(); ++i) {
        // NOTE(review): left uninitialized; relies on the reverse lookup below
        // always finding a match. Looks safe because `outputTileIndices`
        // values are exactly 0..size-1, but an init/assert would be cheap.
        long outputTile;

        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }

        // Get that tile's partial result and add it to the list.
        applyFiltersResults.push_back(localPartialResults[outputTile]);
      }

      // Create the yield operation with the given results.
      rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);

      // Update the global partial results map.
      for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
        // NOTE(review): same uninitialized reverse-lookup pattern as above.
        long outputTile;

        // Given an output tile index, find the corresponding output tile.
        for (auto outputTileIndex : outputTileIndices) {
          if (outputTileIndex.second == i) {
            outputTile = outputTileIndex.first;
            break;
          }
        }

        globalPartialResults[outputTile] = currentCompute.getResult(i);
      }

      // Move the rewrite cursor out of the block.
      rewriter.setInsertionPointAfter(currentCompute);
    }

    // ------------------------------ //
    // --- CONCATENATE THE OUTPUT --- //
    // ------------------------------ //

    // Turn the values into a SmallVector.
    SmallVector<Value> outputValues;
    for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
      outputValues.push_back(globalPartialResults[i]);

    // Assert that the number of output values is correct.
    assert(outputValues.size() > 0 && "No output values were generated for the convolution.");

    // If the conv's user is a ReLU...
    if (conv->hasOneUse()) {
      Operation* user = *conv->getUsers().begin();
      if (auto relu = dyn_cast<ONNXReluOp>(user)) {
        // ...then we can just replace the ReLU with the concatenation.
        // Concatenation happens along dimension 1 (the channel axis).
        rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));

        // And erase the convolution.
        rewriter.eraseOp(conv);
        return success();
      }
    }

    // Return the final output.
    rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));

    return success();
  }
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Populate the tiling pattern for a convolution operation.
|
|
||||||
*
|
|
||||||
* @param patterns The pattern set to populate.
|
|
||||||
* @param ctx The MLIR context.
|
|
||||||
*/
|
|
||||||
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  // Register the experimental convolution tiling pattern; `add` is the
  // modern RewritePatternSet spelling of `insert`.
  patterns.add<ExperimentalONNXConvOpTile>(ctx);
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,365 +0,0 @@
|
|||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include <cstdlib>
|
|
||||||
|
|
||||||
#include "Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
|
||||||
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
|
|
||||||
ExperimentalGemmConversionPattern(MLIRContext* ctx)
|
|
||||||
: OpConversionPattern(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
|
||||||
|
|
||||||
// --------------------------------- //
|
|
||||||
// --- READ OPERATION PARAMETERS --- //
|
|
||||||
// --------------------------------- //
|
|
||||||
|
|
||||||
// To get each crossbar's weights, we need to slice the weights tensor.
|
|
||||||
// - Along the input tiles.
|
|
||||||
// - Along the output tiles.
|
|
||||||
// - Along the filter x position.
|
|
||||||
// - Along the filter y position.
|
|
||||||
ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
|
|
||||||
ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
|
|
||||||
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
|
|
||||||
|
|
||||||
// TODO: Address bigger batches.
|
|
||||||
assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
|
|
||||||
|
|
||||||
// TODO: Address replication.
|
|
||||||
assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
|
|
||||||
|
|
||||||
// TODO: Address bias addition.
|
|
||||||
|
|
||||||
assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
|
|
||||||
|
|
||||||
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
|
|
||||||
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
|
|
||||||
size_t kernelWidth = 1;
|
|
||||||
size_t kernelHeight = 1;
|
|
||||||
|
|
||||||
// Assert that the kernel is square.
|
|
||||||
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
|
|
||||||
|
|
||||||
// -------------------------------- //
|
|
||||||
// --- SLICE THE WEIGHTS TENSOR --- //
|
|
||||||
// -------------------------------- //
|
|
||||||
|
|
||||||
// The core idea of this stage is classifying the weights by input and
|
|
||||||
// output tile. This is because we want the applyFilters operations to be
|
|
||||||
// tile agnostic, to keep the subsequent lowering stages as simple as
|
|
||||||
// possible. This data structure does this weight classification:
|
|
||||||
// - The outer map is indexed by input tile.
|
|
||||||
// - The inner map is indexed by output tile.
|
|
||||||
// - The SmallVector contains the weights for the filter.
|
|
||||||
map<long, map<long, SmallVector<Value>>> weightsGroups;
|
|
||||||
|
|
||||||
// During all slicing operations within this stage, we'll use the same
|
|
||||||
// strides for all dimensions.
|
|
||||||
SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
|
|
||||||
|
|
||||||
ldiv_t itc = inputTileCount;
|
|
||||||
ldiv_t otc = outputTileCount;
|
|
||||||
|
|
||||||
// - Slicing along the input tiles.
|
|
||||||
// - Slicing along the output tiles.
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
|
||||||
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
|
|
||||||
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
|
|
||||||
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
|
|
||||||
|
|
||||||
// The loop above also sets the crossbar's used width and height,
|
|
||||||
// checking if we're at the last crossbar and if it's incomplete.
|
|
||||||
|
|
||||||
long outputTile = ot;
|
|
||||||
long inputTile = it;
|
|
||||||
|
|
||||||
// Create the slicing sizes.
|
|
||||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
|
||||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
|
||||||
/* 2 */ /* rewriter.getIndexAttr(1), */
|
|
||||||
/* 3 */ /* rewriter.getIndexAttr(1) */};
|
|
||||||
|
|
||||||
// - Slicing along the filter x position.
|
|
||||||
// - Slicing along the filter y position.
|
|
||||||
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
|
|
||||||
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
|
|
||||||
|
|
||||||
// Create the slicing offsets.
|
|
||||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
|
||||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
|
||||||
/* 2 */ /* rewriter.getIndexAttr(filterX), */
|
|
||||||
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
|
|
||||||
|
|
||||||
// Create the slice extraction operation.
|
|
||||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
|
||||||
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
|
|
||||||
|
|
||||||
// Add a note to the extractSliceOp, with the filterX and filterY.
|
|
||||||
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Tree reduction for compute reduction should be implemented.
|
|
||||||
|
|
||||||
// -------------------------------- //
|
|
||||||
// --- CREATE ALL COMPUTE UNITS --- //
|
|
||||||
// -------------------------------- //
|
|
||||||
|
|
||||||
// Keep track of input slicing operations to avoid duplication across
|
|
||||||
// all compute units (global slices).
|
|
||||||
map<long, Value> globalSlices;
|
|
||||||
|
|
||||||
// Keep track of all partial compute results.
|
|
||||||
map<long, Value> globalPartialResults;
|
|
||||||
|
|
||||||
// Use a weight subdivider to extract groups of weights for each compute
|
|
||||||
// unit. We'll keep extracting groups until no more weights are left.
|
|
||||||
WeightSubdivider weightSubdivider(weightsGroups);
|
|
||||||
while (!weightSubdivider.isEmpty()) {
|
|
||||||
|
|
||||||
// -------------------------------- //
|
|
||||||
// --- BEGIN A NEW COMPUTE UNIT --- //
|
|
||||||
// -------------------------------- //
|
|
||||||
|
|
||||||
// Get the next group of weights for the compute unit.
|
|
||||||
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
|
||||||
|
|
||||||
SmallVector<Value> computeWeights;
|
|
||||||
SmallVector<Value> computeOperands;
|
|
||||||
|
|
||||||
// ------------------------------ //
|
|
||||||
// --- SLICE THE INPUT TENSOR --- //
|
|
||||||
// ------------------------------ //
|
|
||||||
|
|
||||||
// Note each tile's index in the compute unit arguments.
|
|
||||||
map<long, size_t> inputTileIndices;
|
|
||||||
map<long, size_t> outputTileIndices;
|
|
||||||
map<long, size_t> reductionTileIndices; // Incoming partial results.
|
|
||||||
|
|
||||||
// Iterate over all weights groups for this compute unit.
|
|
||||||
map<long, Value> localSlices; // WRT the current compute unit.
|
|
||||||
for (auto group : weightsGroups) {
|
|
||||||
for (Value weight : group.weights)
|
|
||||||
computeWeights.push_back(weight);
|
|
||||||
|
|
||||||
// There might be multiple weight groups for the same input tile, so if
|
|
||||||
// we've already added the input tile, skip it.
|
|
||||||
if (localSlices.find(group.inputTile) != localSlices.end())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// We might have already sliced the input tensor for some other compute
|
|
||||||
// unit, so if we have, reuse the slicing operation without creating a
|
|
||||||
// new one.
|
|
||||||
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
|
|
||||||
computeOperands.push_back(globalSlices[group.inputTile]);
|
|
||||||
localSlices[group.inputTile] = globalSlices[group.inputTile];
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the input tensor slicing offsets.
|
|
||||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
|
||||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
|
||||||
/* 2 */ /* rewriter.getIndexAttr(0), */
|
|
||||||
/* 3 */ /* rewriter.getIndexAttr(0) */};
|
|
||||||
|
|
||||||
// Create the input tensor slicing sizes.
|
|
||||||
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
|
|
||||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
|
||||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
|
||||||
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
|
|
||||||
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
|
|
||||||
|
|
||||||
// Create the slice extraction operation.
|
|
||||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
|
||||||
gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
|
|
||||||
|
|
||||||
computeOperands.push_back(extractSliceOp);
|
|
||||||
|
|
||||||
// Update slicing maps.
|
|
||||||
globalSlices[group.inputTile] = extractSliceOp;
|
|
||||||
localSlices[group.inputTile] = extractSliceOp;
|
|
||||||
|
|
||||||
// Update the input tile index.
|
|
||||||
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------- //
|
|
||||||
// --- PREPARE THE OUTPUT TYPE --- //
|
|
||||||
// ------------------------------- //
|
|
||||||
|
|
||||||
// Fill the compute output's type by looking at the output tiles.
|
|
||||||
SmallVector<Type> computeOutputType;
|
|
||||||
for (TaggedWeights group : weightsGroups) {
|
|
||||||
|
|
||||||
// There might be multiple weight groups for the same output tile, so if
|
|
||||||
// we've already added the output tile, skip it.
|
|
||||||
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Additionally, after adding the input slices as operands, also add any
|
|
||||||
// compatible partial results from previous compute units.
|
|
||||||
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
|
|
||||||
computeOperands.push_back(globalPartialResults[group.outputTile]);
|
|
||||||
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define the output shape for this group.
|
|
||||||
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
|
|
||||||
|
|
||||||
// TODO: Address non-same padding.
|
|
||||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
|
||||||
/* 1 */ outputTileSize,
|
|
||||||
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
|
|
||||||
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
|
|
||||||
|
|
||||||
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
|
|
||||||
|
|
||||||
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
|
||||||
|
|
||||||
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// ----------------------------- //
|
|
||||||
// --- FILL THE COMPUTE UNIT --- //
|
|
||||||
// ----------------------------- //
|
|
||||||
|
|
||||||
// Create the compute unit.
|
|
||||||
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
|
|
||||||
gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
|
|
||||||
|
|
||||||
// Create a new block for the compute unit and add the operands.
|
|
||||||
Block* block = rewriter.createBlock(¤tCompute.getRegion());
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
for (Value operand : computeOperands)
|
|
||||||
block->addArgument(operand.getType(), gemmOp->getLoc());
|
|
||||||
|
|
||||||
// Initialize a map of local partial results.
|
|
||||||
map<long, Value> localPartialResults; // WRT the current compute unit.
|
|
||||||
|
|
||||||
// If we have any reduction tiles, add them to the local partial results.
|
|
||||||
for (auto reductionTileIndex : reductionTileIndices)
|
|
||||||
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
|
|
||||||
|
|
||||||
// Add all the applyFilters operations to the block.
|
|
||||||
for (TaggedWeights group : weightsGroups) {
|
|
||||||
|
|
||||||
// Get the outputType for this group.
|
|
||||||
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
|
|
||||||
|
|
||||||
// Create an apply filters operation.
|
|
||||||
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
|
|
||||||
|
|
||||||
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
|
|
||||||
// ... As many weights as the size of group.weights.
|
|
||||||
SmallVector<long> weightIndices;
|
|
||||||
for (size_t i = 0; i < group.weights.size(); ++i)
|
|
||||||
weightIndices.push_back(group.startingCrossbarIndex + i);
|
|
||||||
|
|
||||||
SmallVector<int64_t> xKerPos;
|
|
||||||
SmallVector<int64_t> yKerPos;
|
|
||||||
for (auto weight : group.weights) {
|
|
||||||
// Assert that the weight is an extract_slice operation.
|
|
||||||
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
|
|
||||||
assert(extractSliceOp && "Weight is not an extract_slice operation.");
|
|
||||||
|
|
||||||
// Get the filter x and y positions from the extract_slice operation.
|
|
||||||
xKerPos.push_back(0);
|
|
||||||
yKerPos.push_back(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
|
|
||||||
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
|
|
||||||
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
|
|
||||||
|
|
||||||
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
|
|
||||||
gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
|
|
||||||
|
|
||||||
// Perform local reduction if necessary.
|
|
||||||
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
|
|
||||||
|
|
||||||
result = rewriter.create<spatial::SpatVAddOp>(
|
|
||||||
gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the partial results map.
|
|
||||||
localPartialResults[group.outputTile] = result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add a yield operation to the block by concatenating the partial
|
|
||||||
// results.
|
|
||||||
SmallVector<Value> applyFiltersResults;
|
|
||||||
for (size_t i = 0; i < computeOutputType.size(); ++i) {
|
|
||||||
long outputTile;
|
|
||||||
|
|
||||||
// Given an output tile index, find the corresponding output tile.
|
|
||||||
for (auto outputTileIndex : outputTileIndices) {
|
|
||||||
if (outputTileIndex.second == i) {
|
|
||||||
outputTile = outputTileIndex.first;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get that tile's partial result and add it to the list.
|
|
||||||
applyFiltersResults.push_back(localPartialResults[outputTile]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the yield operation with the given results.
|
|
||||||
rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
|
|
||||||
|
|
||||||
// Update the global partial results map.
|
|
||||||
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
|
|
||||||
long outputTile;
|
|
||||||
|
|
||||||
// Given an output tile index, find the corresponding output tile.
|
|
||||||
for (auto outputTileIndex : outputTileIndices) {
|
|
||||||
if (outputTileIndex.second == i) {
|
|
||||||
outputTile = outputTileIndex.first;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
globalPartialResults[outputTile] = currentCompute.getResult(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move the rewrite cursor out of the block.
|
|
||||||
rewriter.setInsertionPointAfter(currentCompute);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------ //
|
|
||||||
// --- CONCATENATE THE OUTPUT --- //
|
|
||||||
// ------------------------------ //
|
|
||||||
|
|
||||||
// Turn the values into a SmallVector.
|
|
||||||
SmallVector<Value> outputValues;
|
|
||||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
|
|
||||||
outputValues.push_back(globalPartialResults[i]);
|
|
||||||
|
|
||||||
// Assert that the number of output values is correct.
|
|
||||||
assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
|
|
||||||
|
|
||||||
// Return the final output.
|
|
||||||
rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Registers the experimental ONNX GEMM lowering pattern into the given
/// pattern set so the dialect-conversion driver can apply it.
///
/// @param patterns Pattern set populated by the enclosing conversion pass.
/// @param ctx      MLIRContext used to construct the pattern.
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<ExperimentalGemmConversionPattern>(ctx);
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
|
||||||
|
|
||||||
/// Conversion pattern that tiles an ONNX Gemm (Y = alpha*A*B + beta*C) into
/// crossbar-sized spatial compute units (spatial::SpatWeightedCompute),
/// partial-sum reductions, and a final tensor.concat along the output axis.
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
  ONNXGemmOpTile(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  /// Lowers one ONNXGemmOp:
  ///  1. Materializes transA/transB via ONNXTransposeOp and alpha/beta
  ///     scaling via spatial::SpatVMulOp against 1x1 constant tensors.
  ///  2. Slices A per crossbar/core and tiles B into crossbar-sized tiles.
  ///  3. For each output horizontal slice, builds one SpatWeightedCompute
  ///     per core (VMM per A-slice + local sum), then a second
  ///     SpatWeightedCompute that sums all per-core partial results (and the
  ///     matching C slice, when present).
  ///  4. Replaces the Gemm with a tensor.concat of the output slices.
  ///
  /// @return success() always; static-shape and rank preconditions are
  ///         enforced with asserts rather than match failure.
  LogicalResult
  matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    Location gemmLoc = gemmOp.getLoc();
    Value a = adaptor.getA();
    Value b = adaptor.getB();
    Value c = adaptor.getC();
    Value out = gemmOp.getY();

    float alpha = adaptor.getAlpha().convertToFloat();
    float beta = adaptor.getBeta().convertToFloat();
    bool transA = adaptor.getTransA();
    bool transB = adaptor.getTransB();

    auto aType = cast<RankedTensorType>(a.getType());
    auto bType = cast<RankedTensorType>(b.getType());
    auto outType = cast<RankedTensorType>(out.getType());

    // C is optional in ONNX Gemm; "absent" is modeled here as an ONNXNoneOp
    // producer.
    // NOTE(review): if `c` were a block argument, getDefiningOp() would
    // return null and this isa<> would dereference it — confirm `c` is always
    // op-produced at this point in the pipeline.
    RankedTensorType cType = nullptr;
    bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
    if (hasC) {
      cType = cast<RankedTensorType>(c.getType());
      // The string literal keeps the message visible in the assert output.
      assert("Only support 2 tensor for C" && cType.getRank() == 2);
    }

    assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
        && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());

    // Materialize the transA/transB flags as explicit transposes so the
    // slicing below can assume canonical layouts.
    if (transA) {
      auto aShape = aType.getShape();
      auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
      a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
    }
    if (transB) {
      auto bShape = bType.getShape();
      auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
      b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
    }

    // Fold alpha into A and beta into C up front (skipped when they are 1.0,
    // the ONNX default), as elementwise multiplies by a 1x1 constant.
    if (alpha != 1.0f) {
      auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
      auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
      auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
      a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
    }
    if (hasC && beta != 1.0f) {
      auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
      auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
      auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
      c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
    }

    // Tile counts: A is sliced horizontally per crossbar; B's vertical tiling
    // mirrors A's horizontal tiling (shared contraction dim), and C/out
    // mirror B's horizontal tiling (shared output dim).
    auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
    auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
    auto bNumVSlices = aNumHSlices;
    auto bLastVSliceSize = aLastHSliceSize;
    auto cNumHSlices = bNumHSlices;
    auto cLastHSliceSize = bLastHSliceSize;
    auto outNumHSlices = cNumHSlices;
    auto outLastHSliceSize = cLastHSliceSize;

    // Number of cores needed so each core holds at most crossbarCountInCore
    // vertical slices of B.
    const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

    DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);

    DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
        tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);

    // A scalar (1x1) C is broadcast to a full row vector before slicing so it
    // lines up with the output slices.
    SmallVector<Value> cHSlices;
    if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
      c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
    if (hasC)
      cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);

    // Output-slice types: full crossbar width, plus a narrower type for a
    // ragged final slice.
    RankedTensorType outHSliceType =
        RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
    RankedTensorType outLastHSliceType =
        RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());

    SmallVector<Value> outHSlices;
    outHSlices.reserve(outNumHSlices);
    for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
      RankedTensorType currOutHSliceType = outHSliceType;
      if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
        currOutHSliceType = outLastHSliceType;

      // One partial result per core, each computed inside its own
      // SpatWeightedCompute region.
      SmallVector<Value> partialResults;
      partialResults.reserve(coresPerVSlice);
      for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
        // Weight tiles for this (output slice, core) pair, one per A slice
        // mapped to the core.
        SmallVector<Value> weights;
        weights.reserve(aHSlices[coreId].size());

        for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
          weights.push_back(bTiles[outSliceId][coreId][aSliceId]);

        auto computeOp =
            rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);

        // Build the compute region: one block argument per A slice.
        auto* computeBlock = new Block();
        for (auto aHSlice : aHSlices[coreId])
          computeBlock->addArgument(aHSlice.getType(), gemmLoc);
        computeOp.getBody().push_back(computeBlock);
        rewriter.setInsertionPointToStart(computeBlock);

        // Inside the region: one weighted VMM per A slice, then sum them.
        // NOTE(review): this loop bounds on aNumHSlices but indexes
        // computeArgs, which has aHSlices[coreId].size() entries — if a core
        // holds fewer slices than aNumHSlices this reads out of range;
        // confirm sliceVectorPerCrossbarPerCore's distribution.
        auto computeArgs = computeBlock->getArguments();
        SmallVector<Value> vmmOutputs;
        vmmOutputs.reserve(computeArgs.size());
        for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
          vmmOutputs.push_back(
              rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
        assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");

        Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
        rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
        rewriter.setInsertionPointAfter(computeOp);

        partialResults.push_back(computeOp.getResult(0));
      }

      // The bias slice participates in the final reduction like any other
      // partial result.
      if (hasC) {
        Value cHSlice = cHSlices[outSliceId];
        partialResults.push_back(cHSlice);
      }

      // Second-stage compute: sum all per-core partial results (no weights).
      auto reduceComputeOp =
          rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);

      auto* reduceBlock = new Block();
      for (auto partialResult : partialResults)
        reduceBlock->addArgument(partialResult.getType(), gemmLoc);
      reduceComputeOp.getBody().push_back(reduceBlock);
      rewriter.setInsertionPointToStart(reduceBlock);

      auto blockArgs = reduceBlock->getArguments();
      Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
      rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
      rewriter.setInsertionPointAfter(reduceComputeOp);

      outHSlices.push_back(reduceComputeOp.getResult(0));
    }

    // Stitch the output slices back together along the column axis and
    // replace the original Gemm.
    rewriter.setInsertionPoint(gemmOp);
    auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
    rewriter.replaceOp(gemmOp, concatOp);
    return success();
  }

private:
  /**
   * Resolves the ONNXExpOp from the use chain of the given start value.
   *
   * This function traverses the use chain of the start value until it finds an
   * ONNXExpOp. It returns the value of the ONNXExpOp.
   *
   * Traversal always follows operand 0 of each defining op; asserts (rather
   * than failing gracefully) if the chain ends without finding an ONNXExpOp.
   *
   * @param startValue The starting value of the use chain.
   * @return The value of the ONNXExpOp found in the use chain.
   */
  static Value resolveONNXExpOpFromUseChain(Value startValue) {
    Value walker = startValue;

    while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
      walker = walker.getDefiningOp()->getOperand(0);

      assert(walker && walker.getDefiningOp()
          && "Unwinded the whole chain of operations while trying to "
             "find ONNXExpOp, but did not find it");
    }

    // Make sure the dividend is actually produced by an ONNXExpOp
    assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
        && "Old output tile (softmax reducer) is not produced by an "
           "ONNXExpOp");

    return walker;
  }

  // Softmax is a special case, as it requires another reduction after the
  // first one. In the cores, `applyReducePattern` already applied
  // f(x) = exp(x) to each tile. This mean that now we just need to
  // reduce-sum these tiles, and then divide each tile by the reduced sum,
  // which is propagated back to the cores via a broadcast channel.
  //
  // On return, outputOpsAndResNums is rewritten in place to point at the
  // newly appended (divided) yield operands of each compute op.
  // NOTE(review): the gemmOp parameter is not used in this body.
  LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
      Value& softmaxChannel,
      ConversionPatternRewriter& rewriter,
      SpatialReducer& reducer,
      ONNXGemmOp& gemmOp,
      Location& loc) const {

    // TODO: Check case with one compute op

    // Cast vector of Value into vector of ComputeOp
    SmallVector<ComputeAndResNum> softmaxOpsToReduce =
        llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
          return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
        }));

    // Scalar f32 tensor type used for the per-tile sums and the divisor.
    RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
    const TensorType scalarTensorType = tensorTypeBuilder;

    // Tree-reduce the per-tile sums; the final callback runs once on the
    // fully reduced divisor.
    reducer.applyReducePattern(
        softmaxOpsToReduce,
        [&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
        /* preprocess = */
        [&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
        [&](Value softmaxDivisor) {
          // Signal that this is the compute with the softmax divisor
          auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
          computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());

          // Broadcast the divisor to all the cores
          rewriter.setInsertionPointAfterValue(softmaxDivisor);
          rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);

          /*
           * softmaxDividend = onnx.exp (...)
           * sum = spat.SumOp(softmaxDividend)
           * [following can be repeated N times, thus walk the use chain]
           * softmaxDivisor = spat.sadd(sum, ...)
           */
          Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));

          // Make sure the dividend is actually produced by an ONNXExpOp
          assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
              && "Dividend of softmax reduction is not an ONNXExpOp");

          // Do not divide here, divide after this
          return softmaxDivisor;
        });

    // In all the cores, insert a ChannelRecvOp and divide the output tile by
    // the reduced denominator.
    outputOpsAndResNums.clear();
    outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
    for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {

      auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());

      Value divisor;

      // Check if this compute contains the softmax divisor: if so, find the
      // ChannelBroadcastSendOp, otherwise receive the value from the channel
      // using ChannelBroadcastReceiveOp
      if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {

        // The loop + assert pair verifies there is exactly one broadcast
        // send inside the region.
        bool found = false;
        for (auto broadcastOp :
            computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
          assert(found == false
              && "More than one ChannelBroadcastSendOp in "
                 "compute? How is this possible?");
          found = true;

          divisor = broadcastOp.getData();
        }

        assert(found
            && "No ChannelBroadcastSendOp in compute where softmax "
               "divisor was specified to be?");
      }
      else {
        rewriter.setInsertionPoint(yieldOp);
        divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
      }

      // Walk the chain of operations until we find the ONNXExpOp: this is
      // needed because some some may have a different amount of `VAddOp`s due
      // to the tree reduction (e.g. some may have no VAddOp, some may have
      // multiples)
      Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));

      // Append exp(tile)/divisor as a new yield operand and record its index
      // for the caller.
      rewriter.setInsertionPoint(yieldOp);
      Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
      auto yieldOperandNum = yieldOp->getNumOperands();
      yieldOp->insertOperands(yieldOperandNum, newOutputTile);

      outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
    }

    return success();
  }
};
|
|
||||||
|
|
||||||
/// Registers the Gemm-tiling conversion pattern (ONNXGemmOpTile) into the
/// given pattern set.
///
/// @param patterns Pattern set populated by the enclosing conversion pass.
/// @param ctx      MLIRContext used to construct the pattern.
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<ONNXGemmOpTile>(ctx);
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,300 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstddef>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Trait: whether pooling op PoolOp needs a post-processing step on each
/// pooled window. Defaults to false; specialized per op type below.
template <typename PoolOp>
bool hasPostProcessExperimentalPoolingWindow() {
  return false;
}
|
|
||||||
|
|
||||||
/// AveragePool needs post-processing: the window sum must be divided by the
/// (effective) kernel size.
template <>
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
  return true;
}
|
|
||||||
|
|
||||||
/// Post-processes one pooled window value for PoolOp. The primary template is
/// a stub returning a null Value; only ops whose
/// hasPostProcessExperimentalPoolingWindow<PoolOp>() trait is true are
/// expected to take this path (see the ONNXAveragePoolOp specialization).
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
    Location loc,
    PoolOp poolOp,
    Value valueToDivide,
    size_t krn_size,
    size_t tilesSkippedByPadding) {
  return nullptr;
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
ONNXAveragePoolOp poolOp,
|
|
||||||
Value valueToDivide,
|
|
||||||
size_t krn_size,
|
|
||||||
size_t tilesSkippedByPadding) {
|
|
||||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
|
||||||
|
|
||||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
|
||||||
|
|
||||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
|
||||||
|
|
||||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
|
||||||
// compatible with the current code generation, which assumes constant to be
|
|
||||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
|
||||||
// directly under func.func (i.e. alongside ComputeOps)
|
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
|
||||||
rewriter.setInsertionPoint(computeOp);
|
|
||||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
|
||||||
scalarTensor,
|
|
||||||
rewriter.getI64IntegerAttr(divisorNumber),
|
|
||||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
|
||||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ReductionOp>
|
|
||||||
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
|
|
||||||
if (inputTiles.size() == 1)
|
|
||||||
return inputTiles[0];
|
|
||||||
|
|
||||||
if (inputTiles.size() == 2) {
|
|
||||||
return rewriter.create<spatial::SpatVMaxOp>(
|
|
||||||
inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
|
|
||||||
SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
|
|
||||||
|
|
||||||
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
|
|
||||||
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
|
|
||||||
|
|
||||||
return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
|
||||||
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|
||||||
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
|
|
||||||
: OpConversionPattern<PoolOp>(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
|
||||||
Value X = adaptor.getX();
|
|
||||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
|
||||||
Value Y = poolOp.getResult();
|
|
||||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
|
||||||
|
|
||||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
|
||||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
|
||||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
|
||||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
|
||||||
|
|
||||||
if (adaptor.getAutoPad() != "NOTSET")
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
|
||||||
|
|
||||||
size_t pad_x, pad_y;
|
|
||||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
|
||||||
if (padUnpackError.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
|
||||||
|
|
||||||
Location loc = poolOp.getLoc();
|
|
||||||
|
|
||||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
|
||||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
|
||||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
|
||||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
|
||||||
|
|
||||||
ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
|
|
||||||
|
|
||||||
// Assert that the input is a tensor.ConcatOp.
|
|
||||||
auto concat = X.getDefiningOp<tensor::ConcatOp>();
|
|
||||||
if (!concat)
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
|
|
||||||
|
|
||||||
// Create a [channel_tile][x][y] array to store the input tiles.
|
|
||||||
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
|
|
||||||
|
|
||||||
// For each argument of the tensor.ConcatOp, resolve the input tiles.
|
|
||||||
for (size_t y = 0; y < input_h; ++y) {
|
|
||||||
for (size_t x = 0; x < input_w; ++x) {
|
|
||||||
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
|
|
||||||
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
|
|
||||||
/* 1 */ rewriter.getIndexAttr(0),
|
|
||||||
/* 2 */ rewriter.getIndexAttr(x),
|
|
||||||
/* 3 */ rewriter.getIndexAttr(y)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
|
||||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
|
||||||
/* 2 */ rewriter.getIndexAttr(1),
|
|
||||||
/* 3 */ rewriter.getIndexAttr(1)};
|
|
||||||
|
|
||||||
// Get the concat's operand that we want to slice.
|
|
||||||
Value concatInput = concat.getOperand(it);
|
|
||||||
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
|
|
||||||
|
|
||||||
inputTiles[it][x][y] = slicedTile;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare the shape of the compute's output.
|
|
||||||
ldiv_t itc = tileCount;
|
|
||||||
SmallVector<Type> outputTileTypes;
|
|
||||||
for (size_t y = 0; y < output_h; ++y) {
|
|
||||||
for (size_t x = 0; x < output_w; ++x) {
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
|
||||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
|
||||||
/* 1 */
|
|
||||||
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
|
|
||||||
/* 2 */ 1,
|
|
||||||
/* 3 */ 1};
|
|
||||||
|
|
||||||
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
|
|
||||||
|
|
||||||
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a plain value list of the input tiles.
|
|
||||||
SmallVector<Value> inputTilesList;
|
|
||||||
for (size_t y = 0; y < input_h; ++y) {
|
|
||||||
for (size_t x = 0; x < input_w; ++x)
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
|
|
||||||
inputTilesList.push_back(inputTiles[it][y][x]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a single compute to calculate the output.
|
|
||||||
auto computeOp =
|
|
||||||
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
|
|
||||||
|
|
||||||
// Create a new block for the compute unit and add the operands.
|
|
||||||
Block* block = rewriter.createBlock(&computeOp.getRegion());
|
|
||||||
|
|
||||||
// Fill the block arguments and keep a reference to them.
|
|
||||||
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
|
|
||||||
for (size_t y = 0; y < input_h; ++y) {
|
|
||||||
for (size_t x = 0; x < input_w; ++x) {
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
|
||||||
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
|
|
||||||
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Begin writing in the block.
|
|
||||||
rewriter.setInsertionPointToStart(block);
|
|
||||||
|
|
||||||
// Go through all pooling blocks.
|
|
||||||
SmallVector<Value> outputTiles;
|
|
||||||
for (size_t y = 0; y < output_h; ++y) {
|
|
||||||
for (size_t x = 0; x < output_w; ++x) {
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
|
||||||
size_t start_x = x * stride_x;
|
|
||||||
size_t start_y = y * stride_y;
|
|
||||||
size_t end_x = std::min(start_x + krn_w, input_w);
|
|
||||||
size_t end_y = std::min(start_y + krn_h, input_h);
|
|
||||||
|
|
||||||
SmallVector<Value> inputTilesToReduce;
|
|
||||||
for (size_t ky = start_y; ky < end_y; ++ky)
|
|
||||||
for (size_t kx = start_x; kx < end_x; ++kx)
|
|
||||||
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
|
|
||||||
|
|
||||||
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
|
|
||||||
|
|
||||||
// If the reduce op is add, we need to divide the result by the
|
|
||||||
// number of elements in the pooling window.
|
|
||||||
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
|
|
||||||
// Add a spat.const before the computeOp.
|
|
||||||
rewriter.setInsertionPoint(computeOp);
|
|
||||||
auto divisorValue =
|
|
||||||
rewriter.create<spatial::SpatConstantOp>(loc,
|
|
||||||
RankedTensorType::get({1}, rewriter.getF32Type()),
|
|
||||||
rewriter.getI64IntegerAttr(krn_w * krn_h),
|
|
||||||
rewriter.getBoolAttr(true));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
|
|
||||||
reduceResult =
|
|
||||||
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
|
|
||||||
}
|
|
||||||
outputTiles.push_back(reduceResult);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a YieldOp to return the output tiles.
|
|
||||||
rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
|
|
||||||
|
|
||||||
// Set the rewrite cursor right after the computeOp.
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
|
|
||||||
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
|
|
||||||
for (size_t y = 0; y < output_h; ++y) {
|
|
||||||
for (size_t x = 0; x < output_w; ++x) {
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
|
||||||
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
|
|
||||||
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now create spat.img.concat ops to concatenate the output tiles.
|
|
||||||
SmallVector<Value> outputTilesList;
|
|
||||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
|
||||||
SmallVector<Value> imgConcatTiles;
|
|
||||||
for (size_t y = 0; y < output_h; ++y)
|
|
||||||
for (size_t x = 0; x < output_w; ++x)
|
|
||||||
imgConcatTiles.push_back(computeOutput[it][y][x]);
|
|
||||||
|
|
||||||
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
|
|
||||||
|
|
||||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
|
||||||
/* 1 */ (long) tilingSize,
|
|
||||||
/* 2 */ (long) output_w,
|
|
||||||
/* 3 */ (long) output_h};
|
|
||||||
|
|
||||||
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
|
|
||||||
|
|
||||||
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
|
|
||||||
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new tensor.ConcatOp to concatenate the output tiles.
|
|
||||||
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
|
|
||||||
|
|
||||||
rewriter.replaceOp(poolOp, outputTensor);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.insert<
|
|
||||||
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
|
|
||||||
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
|
|
||||||
ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,427 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstddef>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
|
||||||
|
|
||||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
std::function<Value(const Value&, const Value&)> reduce,
|
|
||||||
std::function<Value(const Value&)> preprocess,
|
|
||||||
std::function<Value(const Value&)> postprocess) {
|
|
||||||
// Simple case: if we have only one input, just return it
|
|
||||||
if (valuesToReduce.size() == 1)
|
|
||||||
return valuesToReduce[0];
|
|
||||||
|
|
||||||
if (preprocess) {
|
|
||||||
for (auto& valToReduce : valuesToReduce) {
|
|
||||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
|
||||||
valToReduce = preprocess(valToReduce);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// It is possible that `valuesToReduce` contains two entries for the same
|
|
||||||
// computeOp. In this case, we need to apply the reduction within-computef
|
|
||||||
|
|
||||||
// Keep a map between a computeOp and the last Value for this reduction
|
|
||||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
|
||||||
for (auto& valToReduce : valuesToReduce) {
|
|
||||||
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
|
|
||||||
// if (valToReduce.getDefiningOp()) {
|
|
||||||
// // If the value is defined by an operation, we take the parent
|
|
||||||
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
|
|
||||||
// } else {
|
|
||||||
// // Otherwise it is a block argument,
|
|
||||||
// computeOp->getBlock()->getParentOp();
|
|
||||||
// }
|
|
||||||
|
|
||||||
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
|
|
||||||
|
|
||||||
auto it = lastValueForCompute.find(computeOp);
|
|
||||||
|
|
||||||
if (it != lastValueForCompute.end()) {
|
|
||||||
// If we have already seen this computeOp, apply the reduction
|
|
||||||
// within-compute
|
|
||||||
Value lastWithinComputeValue = it->second;
|
|
||||||
|
|
||||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
|
||||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
|
||||||
else
|
|
||||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
|
||||||
valToReduce = reduce(lastWithinComputeValue, valToReduce);
|
|
||||||
lastValueForCompute[computeOp] = valToReduce;
|
|
||||||
}
|
|
||||||
|
|
||||||
lastValueForCompute[computeOp] = valToReduce;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, reconstruct from the map the valuesToReduce list
|
|
||||||
valuesToReduce.clear();
|
|
||||||
valuesToReduce.reserve(lastValueForCompute.size());
|
|
||||||
for (auto& entry : lastValueForCompute)
|
|
||||||
valuesToReduce.push_back(entry.second);
|
|
||||||
|
|
||||||
Location loc = valuesToReduce[0].getLoc();
|
|
||||||
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
|
|
||||||
|
|
||||||
// Recursive algorithm to reduce the inputs to a single one:
|
|
||||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
|
||||||
// the valuesToReduce list which becomes half the size.
|
|
||||||
// - Repeat until there is only one input left.
|
|
||||||
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
|
|
||||||
while (valuesToReduceRef.size() > 1) {
|
|
||||||
SmallVector<Value> nextValuesToReduce;
|
|
||||||
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
|
|
||||||
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
|
|
||||||
auto firstValue = valuesToReduceRef[i];
|
|
||||||
auto secondValue = valuesToReduceRef[i + 1];
|
|
||||||
|
|
||||||
auto firstCompute = firstValue.getParentBlock()->getParentOp();
|
|
||||||
auto secondCompute = secondValue.getParentBlock()->getParentOp();
|
|
||||||
|
|
||||||
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
|
|
||||||
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
|
|
||||||
|
|
||||||
if (secondCompute->isBeforeInBlock(firstCompute)) {
|
|
||||||
std::swap(firstValue, secondValue);
|
|
||||||
std::swap(firstCompute, secondCompute);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. Add a channel before the first computeOp
|
|
||||||
rewriter.setInsertionPoint(firstCompute);
|
|
||||||
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
|
|
||||||
|
|
||||||
// 2. Add a sendOp after the first value
|
|
||||||
rewriter.setInsertionPointAfterValue(firstValue);
|
|
||||||
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue);
|
|
||||||
|
|
||||||
// 3. Add a receiveOp after the second value
|
|
||||||
rewriter.setInsertionPointAfterValue(secondValue);
|
|
||||||
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
|
|
||||||
|
|
||||||
// 4. Apply reduction between second value and received value
|
|
||||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
|
||||||
Value reduced = reduce(receivedValue, secondValue);
|
|
||||||
|
|
||||||
nextValuesToReduce.push_back(reduced);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have an odd number of inputs, we need to add the last one to the
|
|
||||||
// newInputs list.
|
|
||||||
if (valuesToReduceRef.size() % 2 == 1)
|
|
||||||
nextValuesToReduce.push_back(valuesToReduceRef.back());
|
|
||||||
|
|
||||||
// Replace the inputOps list with the new one.
|
|
||||||
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
|
|
||||||
|
|
||||||
auto finalValue = valuesToReduceRef[0];
|
|
||||||
|
|
||||||
if (postprocess) {
|
|
||||||
rewriter.setInsertionPointAfterValue(finalValue);
|
|
||||||
finalValue = postprocess(finalValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
return finalValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp>
|
|
||||||
bool hasPostProcessPoolingWindow() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp>
|
|
||||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
PoolOp poolOp,
|
|
||||||
Value valueToDivide,
|
|
||||||
size_t krn_size,
|
|
||||||
size_t tilesSkippedByPadding) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
ONNXAveragePoolOp poolOp,
|
|
||||||
Value valueToDivide,
|
|
||||||
size_t krn_size,
|
|
||||||
size_t tilesSkippedByPadding) {
|
|
||||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
|
||||||
|
|
||||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
|
||||||
|
|
||||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
|
||||||
|
|
||||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
|
||||||
// compatible with the current code generation, which assumes constant to be
|
|
||||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
|
||||||
// directly under func.func (i.e. alongside ComputeOps)
|
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
|
||||||
rewriter.setInsertionPoint(computeOp);
|
|
||||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
|
||||||
scalarTensor,
|
|
||||||
rewriter.getI64IntegerAttr(divisorNumber),
|
|
||||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
|
||||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
|
||||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|
||||||
PoolingBaseConverter(MLIRContext* ctx)
|
|
||||||
: OpConversionPattern<PoolOp>(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
|
||||||
Value X = adaptor.getX();
|
|
||||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
|
||||||
Value Y = poolOp.getResult();
|
|
||||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
|
||||||
|
|
||||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
|
||||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
|
||||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
|
||||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
|
||||||
|
|
||||||
if (adaptor.getAutoPad() != "NOTSET")
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
|
||||||
|
|
||||||
size_t pad_x, pad_y;
|
|
||||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
|
||||||
if (padUnpackError.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
|
||||||
|
|
||||||
Location loc = poolOp.getLoc();
|
|
||||||
|
|
||||||
size_t input_h = GET_IMAGE_HEIGHT(xShape);
|
|
||||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
|
||||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
|
||||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
|
||||||
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
|
||||||
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
|
||||||
|
|
||||||
// 1: Tile the input tensor
|
|
||||||
// Input tiles need to be indexed by:
|
|
||||||
// a. Channel Tile
|
|
||||||
// b. Pixel `x` position
|
|
||||||
// c. Pixel `y` position
|
|
||||||
// For example: inputTiles[channelTile][x][y]
|
|
||||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
|
||||||
// Suppose that the input tensor is produced by concatenating the results of
|
|
||||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
|
||||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
|
||||||
|
|
||||||
auto resolveErrorOpt =
|
|
||||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
if (resolveErrorOpt.has_value())
|
|
||||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
|
||||||
|
|
||||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
|
||||||
// can do better.
|
|
||||||
// If some input tiles come from the func.func operands, load
|
|
||||||
// them into a computeOp and yield them
|
|
||||||
for (size_t t = 0; t < channelTileCount; t++) {
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
|
||||||
Location tileLoc = extractSliceOp.getLoc();
|
|
||||||
|
|
||||||
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
|
|
||||||
extractSliceOp.getResultType(),
|
|
||||||
/* xbarWeights =*/ValueRange(),
|
|
||||||
extractSliceOp.getResult());
|
|
||||||
|
|
||||||
Block* tempComputeOpBlock = new Block();
|
|
||||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
|
||||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
|
||||||
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
|
|
||||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
|
||||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2: Tile the output tensor
|
|
||||||
// Output tiles need to be indexed by:
|
|
||||||
// a. Channel Tile
|
|
||||||
// b. Pixel `x` position
|
|
||||||
// c. Pixel `y` position
|
|
||||||
// For example: outputTiles[channelTile][x][y]
|
|
||||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
|
||||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
|
||||||
|
|
||||||
// List of values to pool for each output pixel
|
|
||||||
SmallVector<Value> valuesToPool;
|
|
||||||
|
|
||||||
// Iterate each output tile
|
|
||||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
|
||||||
// Iterate each output pixel
|
|
||||||
for (size_t outX = 0; outX < output_w; outX++) {
|
|
||||||
for (size_t outY = 0; outY < output_h; outY++) {
|
|
||||||
|
|
||||||
// Each output pixel tile is computed by pooling a window of input
|
|
||||||
// pixel tiles
|
|
||||||
valuesToPool.clear();
|
|
||||||
size_t tilesSkippedByPadding = 0;
|
|
||||||
|
|
||||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
|
||||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
|
||||||
|
|
||||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
|
||||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
|
||||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
|
||||||
tilesSkippedByPadding++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Value inputTile = inputTiles[outTile][inX][inY];
|
|
||||||
|
|
||||||
Value valueToPool;
|
|
||||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
|
||||||
|
|
||||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
|
||||||
|
|
||||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
|
||||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
|
||||||
}
|
|
||||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
|
||||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
|
||||||
if (failed(sendOpOpt)) {
|
|
||||||
return rewriter.notifyMatchFailure(poolOp,
|
|
||||||
"ChannelReceiveOp does not have a matching "
|
|
||||||
"ChannelSendOp.");
|
|
||||||
}
|
|
||||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
|
||||||
|
|
||||||
valueToPool = sendOp.getData();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
return rewriter.notifyMatchFailure(poolOp,
|
|
||||||
"Input tile for Pooling is not produced by a "
|
|
||||||
"WeightedComputeOp nor a receiveOp");
|
|
||||||
}
|
|
||||||
|
|
||||||
valuesToPool.push_back(valueToPool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
|
||||||
// assert(computeOpsForPooling.size() != 1 &&
|
|
||||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
|
||||||
// this " "should have been simplified earlier???");
|
|
||||||
|
|
||||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
|
||||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
|
||||||
postProcessFn = [&](const Value prevFinalRes) {
|
|
||||||
return postProcessPoolingWindow(
|
|
||||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
Value reducedWithinCompute = applyReducePatternNew(
|
|
||||||
valuesToPool,
|
|
||||||
rewriter,
|
|
||||||
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
|
|
||||||
nullptr,
|
|
||||||
postProcessFn);
|
|
||||||
|
|
||||||
// Send this value through a channel, and receive it in the
|
|
||||||
// `func.func`. During lowering, we will need to "move it" into the
|
|
||||||
// users computeOps
|
|
||||||
auto computeOpOfReduced =
|
|
||||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
|
||||||
|
|
||||||
// Create a new channel before the computeOp
|
|
||||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
|
||||||
auto reduceChannel =
|
|
||||||
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
|
||||||
|
|
||||||
// Send value through the channel
|
|
||||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
|
||||||
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
|
|
||||||
|
|
||||||
// Receive after the computeOp
|
|
||||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
|
||||||
auto receivedValue =
|
|
||||||
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
|
|
||||||
|
|
||||||
outputTiles[outTile][outX][outY] = receivedValue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
|
||||||
// them!
|
|
||||||
|
|
||||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
|
||||||
|
|
||||||
// Iterate each output tile
|
|
||||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
|
||||||
// Iterate each output pixel
|
|
||||||
for (size_t outX = 0; outX < output_w; outX++) {
|
|
||||||
for (size_t outY = 0; outY < output_h; outY++) {
|
|
||||||
auto outputTile = outputTiles[outTile][outX][outY];
|
|
||||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
|
||||||
if (!outputTileProducer) {
|
|
||||||
return rewriter.notifyMatchFailure(poolOp,
|
|
||||||
"Output tile for Pooling is not produced by a "
|
|
||||||
"WeightedComputeOp.");
|
|
||||||
}
|
|
||||||
|
|
||||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
|
||||||
|
|
||||||
rewriter.replaceOp(poolOp, outputImage);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
|
|
||||||
ctx);
|
|
||||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV13Op> {
|
|
||||||
|
|
||||||
ReduceMeanConversionPattern(MLIRContext* ctx)
|
|
||||||
: OpConversionPattern(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
|
|
||||||
ONNXReduceMeanV13OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
|
||||||
|
|
||||||
// Get the input tensor.
|
|
||||||
Value inputTensor = adaptor.getData();
|
|
||||||
auto inputTensorType = cast<RankedTensorType>(inputTensor.getType());
|
|
||||||
|
|
||||||
// This pattern will substitute the ONNXReduceMeanV13Op with a
|
|
||||||
// ONNXAveragePoolOp with the same input tensor and an appropriate kernel
|
|
||||||
// shape and strides.
|
|
||||||
|
|
||||||
// To get the stride and shape of the kernel, we need to read the tensor
|
|
||||||
// shape.
|
|
||||||
int image_height = inputTensorType.getShape()[2];
|
|
||||||
int image_width = inputTensorType.getShape()[3];
|
|
||||||
|
|
||||||
// Define the kernel shape and strides.
|
|
||||||
SmallVector<int64_t> kernelShapeVals = {image_height, image_width};
|
|
||||||
SmallVector<int64_t> stridesVals = {image_height, image_width};
|
|
||||||
SmallVector<int64_t> dilationsVals = {1, 1};
|
|
||||||
|
|
||||||
// Set the pads to 0.
|
|
||||||
SmallVector<int64_t> padsVals = {0, 0, 0, 0};
|
|
||||||
|
|
||||||
// Create the ArrayAttrs
|
|
||||||
auto kernelShape = mlir::ArrayAttr::get(
|
|
||||||
rewriter.getContext(), llvm::to_vector(llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
|
|
||||||
return rewriter.getI64IntegerAttr(v);
|
|
||||||
})));
|
|
||||||
|
|
||||||
auto strides = mlir::ArrayAttr::get(rewriter.getContext(),
|
|
||||||
llvm::to_vector(llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
|
|
||||||
return rewriter.getI64IntegerAttr(v);
|
|
||||||
})));
|
|
||||||
|
|
||||||
auto dilations = mlir::ArrayAttr::get(
|
|
||||||
rewriter.getContext(), llvm::to_vector(llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
|
|
||||||
return rewriter.getI64IntegerAttr(v);
|
|
||||||
})));
|
|
||||||
|
|
||||||
auto pads = mlir::ArrayAttr::get(rewriter.getContext(),
|
|
||||||
llvm::to_vector(llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
|
|
||||||
return rewriter.getI64IntegerAttr(v);
|
|
||||||
})));
|
|
||||||
|
|
||||||
// Create the resulting tensor type.
|
|
||||||
auto resultType = RankedTensorType::get(
|
|
||||||
/*shape=*/ {inputTensorType.getShape()[0], inputTensorType.getShape()[1], 1, 1},
|
|
||||||
/*elementType=*/inputTensorType.getElementType());
|
|
||||||
|
|
||||||
// Create the ONNXAveragePoolOp.
|
|
||||||
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
|
|
||||||
resultType,
|
|
||||||
inputTensor,
|
|
||||||
/*auto_pad=*/"NOTSET",
|
|
||||||
/*ceil_mode=*/0,
|
|
||||||
/*count_include_pad=*/1,
|
|
||||||
dilations,
|
|
||||||
/*kernel_shape=*/kernelShape,
|
|
||||||
/*pads=*/pads,
|
|
||||||
/*strides=*/strides);
|
|
||||||
|
|
||||||
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
|
|
||||||
rewriter.replaceOp(reduceMean, averagePool.getResult());
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.insert<ReduceMeanConversionPattern>(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -8,60 +8,59 @@ include "src/Dialect/ONNX/ONNX.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
def onnxToArithConstantOp : Pat<
|
def onnxToArithConstant : Pat<
|
||||||
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
|
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
|
||||||
(Arith_ConstantOp $value)
|
(Arith_ConstantOp $value)
|
||||||
>;
|
>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def matMulAddToGemmPattern : Pat<
|
def IsRank2Result: Constraint<
|
||||||
|
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||||
|
"Result is rank 2">;
|
||||||
|
|
||||||
|
def matMulAddToGemm : Pat<
|
||||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||||
(ONNXGemmOp $A, $B, $C,
|
(ONNXGemmOp $A, $B, $C,
|
||||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||||
)
|
),
|
||||||
|
[(IsRank2Result $matmulres)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def matMulToGemmPattern : Pat<
|
def matMulToGemm : Pat<
|
||||||
(ONNXMatMulOp:$matmulres $A, $B),
|
(ONNXMatMulOp:$matmulres $A, $B),
|
||||||
(
|
(
|
||||||
ONNXGemmOp $A, $B,
|
ONNXGemmOp $A, $B,
|
||||||
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||||
)
|
),
|
||||||
|
[(IsRank2Result $matmulres)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
|
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
|
||||||
// ONNXConvOp with a bias.
|
def convAddToConvWithBiasLeft : Pat<
|
||||||
def convAddToConvWithBiasPatternLeft : Pat<
|
|
||||||
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
|
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
|
||||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def convAddToConvWithBiasPatternRight : Pat<
|
def convAddToConvWithBiasRight : Pat<
|
||||||
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
|
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
|
||||||
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
|
||||||
>;
|
>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Operation to ignore (i.e. remove)
|
// Operation to ignore (i.e. remove)
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
|
||||||
|
|
||||||
def removeLRNPattern : Pat<
|
def removeLRN : Pat<
|
||||||
(ONNXLRNOp $A, $_, $_, $_, $_),
|
(ONNXLRNOp $A, $_, $_, $_, $_),
|
||||||
(replaceWithOperationOfValue $A)
|
(replaceWithOperationOfValue $A)
|
||||||
>;
|
>;
|
||||||
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
|
|||||||
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
|
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
|
||||||
"Two tensors have the same static shape">;
|
"Two tensors have the same static shape">;
|
||||||
|
|
||||||
def removeFlattenSameShapePattern : Pat<
|
def removeFlattenSameShape : Pat<
|
||||||
(ONNXFlattenOp:$flattenOp $A, $axis),
|
(ONNXFlattenOp:$flattenOp $A, $axis),
|
||||||
(replaceWithOperationOfValue $A),
|
(replaceWithOperationOfValue $A),
|
||||||
[(HaveSameStaticShape $flattenOp, $A)]
|
[(HaveSameStaticShape $flattenOp, $A)]
|
||||||
>; // Add closing parenthesis here
|
>;
|
||||||
|
|
||||||
#endif // ONNX_TO_SPATIAL
|
#endif // ONNX_TO_SPATIAL
|
||||||
@@ -1,499 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/ADT/Twine.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
SmallVector<Value> sliceTensor(
|
|
||||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
|
||||||
assert("Invalid axis" && axis < shape.size());
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
sizes.reserve(shape.size());
|
|
||||||
for (const auto size : shape)
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(size));
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
|
||||||
|
|
||||||
long length = shape[axis];
|
|
||||||
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
|
|
||||||
SmallVector<Value> slices;
|
|
||||||
slices.reserve(numSlices);
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < numSlices; i++) {
|
|
||||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
|
||||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
|
||||||
|
|
||||||
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides);
|
|
||||||
slices.push_back(slice);
|
|
||||||
}
|
|
||||||
|
|
||||||
return slices;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value>
|
|
||||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
|
||||||
assert("Not a vector" && isVectorShape(shape));
|
|
||||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
|
||||||
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
|
|
||||||
}
|
|
||||||
|
|
||||||
DenseMap<CoreId, SmallVector<Value>>
|
|
||||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
|
||||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
|
||||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
|
||||||
size_t coreId = sliceId / crossbarCountInCore;
|
|
||||||
slicesPerCore[coreId].push_back(slices[sliceId]);
|
|
||||||
}
|
|
||||||
return slicesPerCore;
|
|
||||||
}
|
|
||||||
|
|
||||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
|
||||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
|
|
||||||
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
|
|
||||||
|
|
||||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
|
|
||||||
|
|
||||||
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
|
|
||||||
size_t numHSlices = hSlices.size();
|
|
||||||
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
|
|
||||||
Value hSlice = hSlices[hSliceId];
|
|
||||||
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
|
|
||||||
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
|
|
||||||
size_t coreId = vSliceId / crossbarCountInCore;
|
|
||||||
Value vSlice = vSlices[vSliceId];
|
|
||||||
tiles[hSliceId][coreId].push_back(vSlice);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tiles;
|
|
||||||
}
|
|
||||||
|
|
||||||
tensor::SplatOp
|
|
||||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
|
||||||
Type elementType = oldType.getElementType();
|
|
||||||
int64_t shape[2] = {1, length};
|
|
||||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
|
||||||
|
|
||||||
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
|
|
||||||
SmallVector<Value> index(oldType.getRank(), zero);
|
|
||||||
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult();
|
|
||||||
|
|
||||||
return rewriter.create<tensor::SplatOp>(loc, type, elementValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
|
||||||
if (tensors.size() == 1)
|
|
||||||
return tensors[0];
|
|
||||||
|
|
||||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
|
||||||
SmallVector<Value> tensors2;
|
|
||||||
tensors2.reserve(tensors.size() / 2);
|
|
||||||
|
|
||||||
auto* currTensors = &tensors1;
|
|
||||||
auto* nextTensors = &tensors2;
|
|
||||||
while (currTensors->size() > 1) {
|
|
||||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
|
||||||
Value a = (*currTensors)[i];
|
|
||||||
Value b = (*currTensors)[i + 1];
|
|
||||||
rewriter.setInsertionPointAfterValue(b);
|
|
||||||
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
|
|
||||||
nextTensors->push_back(addedValue);
|
|
||||||
}
|
|
||||||
if (currTensors->size() % 2 == 1)
|
|
||||||
nextTensors->push_back(currTensors->back());
|
|
||||||
std::swap(currTensors, nextTensors);
|
|
||||||
nextTensors->clear();
|
|
||||||
}
|
|
||||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
|
||||||
return (*currTensors)[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
|
|
||||||
switch (mapOp) {
|
|
||||||
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
|
|
||||||
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input);
|
|
||||||
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input);
|
|
||||||
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input);
|
|
||||||
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2) {
|
|
||||||
if (auto unpackedStrides = valuesArray) {
|
|
||||||
value1 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[0]).getInt();
|
|
||||||
value2 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[1]).getInt();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
value1 = 1;
|
|
||||||
value2 = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<llvm::Twine>
|
|
||||||
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y) {
|
|
||||||
if (valuesArray.has_value()) {
|
|
||||||
auto pads = mlir::ArrayAttr(*valuesArray);
|
|
||||||
if (pads.size() != 4)
|
|
||||||
return "pads must have 4 elements.";
|
|
||||||
|
|
||||||
pad_x = cast<IntegerAttr>(pads[2]).getInt();
|
|
||||||
pad_y = cast<IntegerAttr>(pads[3]).getInt();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Default padding is 0 unless specified otherwise.
|
|
||||||
// https://onnx.ai/onnx/operators/onnx__Conv.html
|
|
||||||
pad_x = pad_y = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
void tileImageTensorByChannel(Value imageTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
|
|
||||||
size_t tileSize,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
|
|
||||||
|
|
||||||
size_t input_h = GET_IMAGE_HEIGHT(imageShape);
|
|
||||||
size_t input_w = GET_IMAGE_WIDTH(imageShape);
|
|
||||||
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize);
|
|
||||||
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize;
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
|
|
||||||
Location loc = imageTensor.getLoc();
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tileCount; i++) {
|
|
||||||
if (i == tileCount - 1 && tileRest != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(tileRest);
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
offsets[1] = rewriter.getIndexAttr(i * tileSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
|
|
||||||
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location& loc,
|
|
||||||
Type outputType) {
|
|
||||||
// Populate the outputTiles for the concat in the given order:
|
|
||||||
// 1. Start top left pixel
|
|
||||||
// 2. Continue on its right pixel till the end of the row
|
|
||||||
// 3. Restart on the next row
|
|
||||||
size_t outputTileCount = outputTiles.size();
|
|
||||||
size_t output_w = outputTiles[0].size();
|
|
||||||
size_t output_h = outputTiles[0][0].size();
|
|
||||||
SmallVector<Value> tilesToConcat;
|
|
||||||
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
|
|
||||||
for (size_t outX = 0; outX < output_h; outX++)
|
|
||||||
for (size_t outY = 0; outY < output_w; outY++)
|
|
||||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
|
||||||
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
|
|
||||||
|
|
||||||
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y) {
|
|
||||||
|
|
||||||
if (inX < 0) {
|
|
||||||
assert((size_t) (-inX) <= pad_x && "verifyWithinBoundsAndPaddings: Negative x value out of padding");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (inY < 0) {
|
|
||||||
assert((size_t) (-inY) <= pad_y && "verifyWithinBoundsAndPaddings: Negative y value out of padding");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((size_t) inX >= input_w || (size_t) inY >= input_h) {
|
|
||||||
assert((size_t) inX < input_w + pad_x && "verifyWithinBoundsAndPaddings: Positive x out of bounds");
|
|
||||||
assert((size_t) inY < input_h + pad_y && "verifyWithinBoundsAndPaddings: Positive y out of bounds");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createExtractSliceImg(Value valToSlice,
|
|
||||||
size_t x,
|
|
||||||
size_t y,
|
|
||||||
size_t t,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
|
|
||||||
if (t == channelTileCount - 1 && channelTileRest != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(channelTileRest);
|
|
||||||
|
|
||||||
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
|
|
||||||
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value indexImgValue(Value v,
|
|
||||||
size_t x,
|
|
||||||
size_t y,
|
|
||||||
size_t t,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
|
|
||||||
auto newV = rewriter.getRemappedValue(v);
|
|
||||||
if (newV)
|
|
||||||
v = newV;
|
|
||||||
|
|
||||||
if (!v.getDefiningOp())
|
|
||||||
return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
|
|
||||||
if (auto computeOp = v.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
|
||||||
// We found the computeOp that produces the tile we want, just return this
|
|
||||||
// value.
|
|
||||||
// TODO: Should we assert that x,y,t are zero?
|
|
||||||
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: WeightedComputeOp tile indeces should be zero");
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto receiveOp = v.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
|
||||||
// This is a receiveOp, just return its value which will be resolved later
|
|
||||||
assert(x == 0 && y == 0 && t == 0 && "indexImgValue: receiveOp tile indeces should be zero");
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto imgConcatOp = v.getDefiningOp<spatial::SpatImgConcatOp>()) {
|
|
||||||
auto imgConcatInput = imgConcatOp.getInputTile(x, y, t);
|
|
||||||
// TODO: Is this correct?
|
|
||||||
// Above we already index exactly the tile we want, so `x=y=t=0` in
|
|
||||||
// recursive call
|
|
||||||
|
|
||||||
return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto tensorConcatOp = v.getDefiningOp<tensor::ConcatOp>()) {
|
|
||||||
// This can be recursive.
|
|
||||||
// First, get the input tensors of the tensor.concatOp
|
|
||||||
// Then, find the input tensor that contains the tile we want
|
|
||||||
// Finally, recursive call asking for the tile
|
|
||||||
auto concatAxis = tensorConcatOp.getDim();
|
|
||||||
assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis");
|
|
||||||
assert(concatAxis == 1 && "TODO: Make sure this works and makes sense for other axis.");
|
|
||||||
SmallVector<size_t, 4> indexDims = {1, t * crossbarSize, x, y};
|
|
||||||
|
|
||||||
// Find the input tensor that contains the tile we want
|
|
||||||
size_t currentTile = 0;
|
|
||||||
for (auto concatInput : tensorConcatOp.getInputs()) {
|
|
||||||
auto concatInputShape = cast<ShapedType>(concatInput.getType());
|
|
||||||
assert(concatInputShape.getRank() == 4 && "Expecting an image tensor");
|
|
||||||
auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis);
|
|
||||||
|
|
||||||
if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) {
|
|
||||||
// This input tensor contains the tile we want
|
|
||||||
indexDims[concatAxis] -= currentTile;
|
|
||||||
if (indexDims[1] % crossbarSize != 0) {
|
|
||||||
assert(ignoreConcatError
|
|
||||||
&& "TODO: Handle non-tile aligned tensor, or set "
|
|
||||||
"--ignore-concat-error=true");
|
|
||||||
}
|
|
||||||
return indexImgValue(concatInput,
|
|
||||||
indexDims[2],
|
|
||||||
indexDims[3],
|
|
||||||
indexDims[1] / crossbarSize,
|
|
||||||
channelTileCount,
|
|
||||||
channelTileRest,
|
|
||||||
input_w,
|
|
||||||
input_h,
|
|
||||||
rewriter);
|
|
||||||
}
|
|
||||||
currentTile += concatInputSizeOnAxis;
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(false
|
|
||||||
&& "Could not find the input tensor that contains the tile "
|
|
||||||
"within tensor.ConcatOp");
|
|
||||||
}
|
|
||||||
|
|
||||||
v.dump();
|
|
||||||
|
|
||||||
assert(false && "indexImgValue: unsupported operation");
|
|
||||||
}
|
|
||||||
|
|
||||||
void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
|
||||||
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Location loc = wholeInputTensor.getLoc();
|
|
||||||
|
|
||||||
for (size_t t = 0; t < channelTileCount; t++) {
|
|
||||||
if (t == channelTileCount - 1 && channelTileRest != 0)
|
|
||||||
sizes[1] = rewriter.getIndexAttr(channelTileRest);
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
|
|
||||||
offsets[2] = rewriter.getIndexAttr(x);
|
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
|
||||||
|
|
||||||
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
|
|
||||||
for (size_t t = 0; t < channelTileCount; t++) {
|
|
||||||
for (size_t x = 0; x < input_w; x++) {
|
|
||||||
for (size_t y = 0; y < input_h; y++) {
|
|
||||||
inputTiles[t][x][y] =
|
|
||||||
indexImgValue(wholeInputTensor, x, y, t, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
|
|
||||||
const size_t inputTilesCount,
|
|
||||||
const size_t lastInputTileDimension,
|
|
||||||
TensorType inputShape,
|
|
||||||
TensorType outputShape,
|
|
||||||
Value reshapeInput,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
// Only support reshape between an image and a vector (i.e. flatten)
|
|
||||||
if (inputShape.getRank() != 4 || outputShape.getRank() != 2) {
|
|
||||||
return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(),
|
|
||||||
"resolveVecInputTiles only supports reshapes from 4D to 2D tensors");
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* From a 4D tensor <N, C, W, H> to a 2D tensor <N, C*H*W>
|
|
||||||
*/
|
|
||||||
auto N = inputShape.getDimSize(0);
|
|
||||||
auto C = inputShape.getDimSize(1);
|
|
||||||
auto H = inputShape.getDimSize(2);
|
|
||||||
auto W = inputShape.getDimSize(3);
|
|
||||||
assert(N == 1 && "Only support N = 1 for image tensors");
|
|
||||||
|
|
||||||
for (size_t i = 0; i < inputTilesCount; i++) {
|
|
||||||
auto c = (i / (H * W)) % C;
|
|
||||||
// TODO: Is this correct? Or should I invert h and w?
|
|
||||||
auto w = (i / H) % W;
|
|
||||||
auto h = i % H;
|
|
||||||
|
|
||||||
Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount, lastInputTileDimension, W, H, rewriter);
|
|
||||||
|
|
||||||
// Assert the shape of the tile, and reshape it
|
|
||||||
auto curTileShape = cast<TensorType>(curTile.getType());
|
|
||||||
assert(curTileShape.getRank() == 4 && "We just reshaped an image tensor, why rank != 4?");
|
|
||||||
assert(curTileShape.getDimSize(0) == 1 && "We just reshaped an image tensor with N = 1, why is it now != 1?");
|
|
||||||
assert(curTileShape.getDimSize(2) == 1 && "We should have just looked up a single pixel why W != 1?");
|
|
||||||
assert(curTileShape.getDimSize(3) == 1 && "We should have just looked up a single pixel why H != 1?");
|
|
||||||
|
|
||||||
// Reshape this pixel tensor into a vector, for compatibility with the
|
|
||||||
// rest
|
|
||||||
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
|
|
||||||
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
|
|
||||||
Value shapeTensor =
|
|
||||||
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
|
|
||||||
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
|
|
||||||
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
|
|
||||||
|
|
||||||
size_t coreIndex = i / crossbarCountInCore;
|
|
||||||
inputTiles[coreIndex].push_back(reshapedCurTile);
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<size_t, size_t> kernel_get_start_and_end(
|
|
||||||
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad) {
|
|
||||||
int64_t firstValid = std::ceil(static_cast<float>(pad) / dilation) * dilation - pad;
|
|
||||||
int64_t start = std::max(firstValid, out_pos * stride - pad);
|
|
||||||
int64_t end = std::min(input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad);
|
|
||||||
|
|
||||||
assert(start >= 0 && "Start position must be non-negative.");
|
|
||||||
assert(end >= 0 && "End position must be non-negative.");
|
|
||||||
return std::make_pair(start, end);
|
|
||||||
}
|
|
||||||
|
|
||||||
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment) {
|
|
||||||
auto oldSegmentSizes = wcomputeOp->getAttrOfType<DenseI32ArrayAttr>(wcomputeOp.getOperandSegmentSizesAttrName());
|
|
||||||
|
|
||||||
auto newSegmentSizes =
|
|
||||||
DenseI32ArrayAttr::get(wcomputeOp->getContext(), {oldSegmentSizes[0], oldSegmentSizes[1] + increment});
|
|
||||||
|
|
||||||
wcomputeOp->setAttr(wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes);
|
|
||||||
}
|
|
||||||
|
|
||||||
int getResultIndex(Operation* op, Value v) {
|
|
||||||
int resultNumber = -1;
|
|
||||||
for (auto result : op->getResults()) {
|
|
||||||
if (result == v) {
|
|
||||||
resultNumber = result.getResultNumber();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert(resultNumber >= 0 && "Value not found in given operation's results.");
|
|
||||||
|
|
||||||
return resultNumber;
|
|
||||||
}
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
@@ -1,263 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
#define DEFINE_MAP_OP(opname) opname,
|
|
||||||
|
|
||||||
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
|
|
||||||
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
|
|
||||||
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
|
|
||||||
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
|
|
||||||
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
|
|
||||||
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
|
|
||||||
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
const StringRef REPLICATION_ATTR_NAME = "replication_factor";
|
|
||||||
|
|
||||||
using HSliceId = size_t;
|
|
||||||
using CoreId = size_t;
|
|
||||||
|
|
||||||
enum class MapOperations {
|
|
||||||
None,
|
|
||||||
ONNXSoftmaxOp,
|
|
||||||
ONNXReluOp,
|
|
||||||
ONNXLeakyReluOp,
|
|
||||||
ONNXExpOp
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr C ceilIntegerDivide(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return 1 + (ac - 1) / bc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
|
||||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
|
||||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
|
||||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
|
||||||
C ac = static_cast<C>(a);
|
|
||||||
C bc = static_cast<C>(b);
|
|
||||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVectorShape(const ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isMatrixShape(const ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isHVectorShape(const ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[0] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
bool isVVectorShape(const ArrayRef<T> shape) {
|
|
||||||
return shape.size() == 2 && shape[1] == 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
T getVectorLength(const ArrayRef<T> shape) {
|
|
||||||
assert(isVectorShape(shape));
|
|
||||||
return shape[0] != 1 ? shape[0] : shape[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); }
|
|
||||||
|
|
||||||
SmallVector<Value> sliceTensor(
|
|
||||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
|
|
||||||
|
|
||||||
SmallVector<Value>
|
|
||||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
|
|
||||||
|
|
||||||
DenseMap<CoreId, SmallVector<Value>>
|
|
||||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc);
|
|
||||||
|
|
||||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
|
||||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc);
|
|
||||||
|
|
||||||
tensor::SplatOp
|
|
||||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc);
|
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unpacks an optional pair vector into two size_t values.
|
|
||||||
*
|
|
||||||
* @param valuesArray The optional `mlir::ArrayAttr` containing the pair of
|
|
||||||
* values.
|
|
||||||
* @param value1 The reference to the first `size_t` variable to store the
|
|
||||||
* unpacked value.
|
|
||||||
* @param value2 The reference to the second `size_t` variable to store the
|
|
||||||
* unpacked value.
|
|
||||||
*/
|
|
||||||
void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& value1, size_t& value2);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unpacks the optional pads vector.
|
|
||||||
*
|
|
||||||
* @param valuesArray The optional array attribute containing the values.
|
|
||||||
* @param pad_x The output variable to store the value of pad_x.
|
|
||||||
* @param pad_y The output variable to store the value of pad_y.
|
|
||||||
* @param rewriter The rewriter to notify failure
|
|
||||||
*
|
|
||||||
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
|
|
||||||
*/
|
|
||||||
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Tiles the image tensor by channel.
|
|
||||||
*
|
|
||||||
* This function takes an image tensor and tiles it into smaller tiles based on
|
|
||||||
* the channel dimension. The size of each tile is specified by the tileSize
|
|
||||||
* parameter.
|
|
||||||
*
|
|
||||||
* @param imageTensor The input image tensor (NxCxWxH) to be tiled.
|
|
||||||
* @param tiles The output tiles vector to store the tiled image tensors.
|
|
||||||
* @param tileSize The size of each tile.
|
|
||||||
* @param rewriter The ConversionPatternRewriter used for creating operations.
|
|
||||||
*/
|
|
||||||
void tileImageTensorByChannel(Value imageTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
|
|
||||||
size_t tileSize,
|
|
||||||
ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an ImgConcatOp based on the given tiles.
|
|
||||||
*
|
|
||||||
* This function takes a 3-dimensional vector `outputTiles` representing the
|
|
||||||
* tiles to concatenate. The tiles are indexed by [tile][x][y].
|
|
||||||
*
|
|
||||||
* @param outputTiles The tiles to concatenate.
|
|
||||||
* @param rewriter The ConversionPatternRewriter used for creating the
|
|
||||||
* ImgConcatOp.
|
|
||||||
* @param loc The location of the operation.
|
|
||||||
* @param outputType The type of the output tensor.
|
|
||||||
*
|
|
||||||
* @return The created ImgConcatOp.
|
|
||||||
*/
|
|
||||||
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location& loc,
|
|
||||||
Type outputType);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Verifies if the given input coordinates and padding values are within
|
|
||||||
* the bounds of the input tensor.
|
|
||||||
*
|
|
||||||
* @param input_w The width of the input tensor.
|
|
||||||
* @param input_h The height of the input tensor.
|
|
||||||
* @param inX The X-coordinate of the input.
|
|
||||||
* @param inY The Y-coordinate of the input.
|
|
||||||
* @param pad_x The padding value in the X-direction.
|
|
||||||
* @param pad_y The padding value in the Y-direction.
|
|
||||||
* @return LogicalResult Returns success if the coordinates and padding are
|
|
||||||
* within bounds, failure otherwise.
|
|
||||||
*/
|
|
||||||
LogicalResult
|
|
||||||
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resolves the tiling of the input tensor into smaller tiles.
|
|
||||||
*
|
|
||||||
* This function takes a whole input tensor and tiles it into smaller tiles
|
|
||||||
* using the provided parameters. The resulting tiles are stored in the
|
|
||||||
* `inputTiles` vector.
|
|
||||||
* Input tiles need to be indexed by:
|
|
||||||
* a. Channel Tile
|
|
||||||
* b. Pixel `x` position
|
|
||||||
* c. Pixel `y` position
|
|
||||||
* For example: inputTiles[channelTile][x][y]
|
|
||||||
*
|
|
||||||
* @param wholeInputTensor The whole input tensor to be tiled.
|
|
||||||
* @param inputTiles A vector of vectors of vectors of Values representing the
|
|
||||||
* tiles of the input tensor. The outermost vector represents
|
|
||||||
* the channels, the middle vector represents the rows, and
|
|
||||||
* the innermost vector represents the columns of the tiles.
|
|
||||||
* @param channelTileCount The number of tiles for the `channel` axis.
|
|
||||||
* @param channelTileRest The size of the last channelTile. Set as 0 if tiles
|
|
||||||
* fit exactly
|
|
||||||
* @param input_w The width of the input tensor.
|
|
||||||
* @param input_h The height of the input tensor.
|
|
||||||
* @param rewriter The ConversionPatternRewriter used for creating operations.
|
|
||||||
*
|
|
||||||
* @return std::optional<llvm::Twine> An error message if the input tensor could
|
|
||||||
* not be resolved into tiles.
|
|
||||||
*/
|
|
||||||
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
|
|
||||||
size_t channelTileCount,
|
|
||||||
size_t channelTileRest,
|
|
||||||
size_t input_w,
|
|
||||||
size_t input_h,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Computes the boundaries of an image kernel application.
|
|
||||||
*
|
|
||||||
* @param out_pos The position of the output element.
|
|
||||||
* @param input_width The width of the input image.
|
|
||||||
* @param krn_width The width of the kernel.
|
|
||||||
* @param stride The stride value.
|
|
||||||
* @param dilation The dilation value.
|
|
||||||
* @param pad The padding value.
|
|
||||||
* @return A pair of size_t values representing the start and end positions of
|
|
||||||
* the kernel application.
|
|
||||||
*/
|
|
||||||
std::pair<size_t, size_t> kernel_get_start_and_end(
|
|
||||||
int64_t out_pos, int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation, int64_t pad);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Increment the `operandSegmentSizes` in the WeightedCompute operation
|
|
||||||
* for the `inputs` operand.
|
|
||||||
*
|
|
||||||
* This function increments the size of the `inputs` operand segment in the
|
|
||||||
* `operandSegmentSizes` of the given WeightedCompute operation by the specified
|
|
||||||
* increment. This is necessary when new operands are programmatically added to
|
|
||||||
* the WeightedCompute operation.
|
|
||||||
*
|
|
||||||
* @param wcomputeOp The WeightedCompute operation whose `operandSegmentSizes`
|
|
||||||
* is to be incremented.
|
|
||||||
* @param increment The value by which to increment the `inputs` operand segment
|
|
||||||
* size.
|
|
||||||
*/
|
|
||||||
void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcomputeOp, int increment);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Finds the result index of the given operation that produces the
|
|
||||||
* specified value.
|
|
||||||
*
|
|
||||||
* This function takes an operation and a value, and returns the index of the
|
|
||||||
* result of the operation that corresponds to the given value.
|
|
||||||
*
|
|
||||||
* @param op Operation whose result index is to be found.
|
|
||||||
* @param v The value for which the result index is to be determined.
|
|
||||||
* @return The index of the result of the operation that produces the specified
|
|
||||||
* value.
|
|
||||||
*/
|
|
||||||
int getResultIndex(Operation* op, Value v);
|
|
||||||
|
|
||||||
}; // namespace onnx_mlir
|
|
||||||
@@ -1,81 +1,97 @@
|
|||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
|
||||||
#include "ONNXToSpatialPass.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace spatial {
|
bool haveSameStaticShape(Value lhs, Value rhs);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||||
|
|
||||||
|
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||||
|
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||||
|
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
|
||||||
|
|
||||||
|
ONNXToSpatialPass() = default;
|
||||||
|
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||||
|
|
||||||
|
void runOnOperation() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
|
|
||||||
|
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
RewritePatternSet mergeActivationPatterns(ctx);
|
RewritePatternSet mergeActivationPatterns(ctx);
|
||||||
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
|
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
|
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||||
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
|
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
||||||
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
||||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
||||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
||||||
|
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||||
|
|
||||||
IRRewriter rewriter(moduleOp);
|
IRRewriter rewriter(moduleOp);
|
||||||
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
if (annotateReplication(funcOp, rewriter).failed()) {
|
if (failed(entryFunc)) {
|
||||||
llvm::dbgs() << "Failed during annotation for replication analysis\n";
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>();
|
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||||
target.addIllegalOp<ONNXMatMulOp>();
|
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||||
|
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||||
target.addIllegalOp<ONNXGemmOp>();
|
target.addIllegalOp<ONNXGemmOp>();
|
||||||
target.addIllegalOp<ONNXConvOp>();
|
target.addIllegalOp<ONNXConvOp>();
|
||||||
target.addIllegalOp<ONNXLRNOp>();
|
|
||||||
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
||||||
target.addIllegalOp<ONNXAveragePoolOp>();
|
target.addIllegalOp<ONNXAveragePoolOp>();
|
||||||
target.addIllegalOp<ONNXConcatOp>();
|
target.addIllegalOp<ONNXReluOp>();
|
||||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||||
|
target.addIllegalOp<ONNXConcatOp>();
|
||||||
|
target.addIllegalOp<ONNXReshapeOp>();
|
||||||
|
target.addIllegalOp<ONNXLRNOp>();
|
||||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
patterns.add<removeLRNPattern>(ctx);
|
patterns.add<removeLRN>(ctx);
|
||||||
|
|
||||||
if (useExperimentalConvImpl) {
|
populateGemmPatterns(patterns, ctx);
|
||||||
populateExperimentalTilingConvOpPattern(patterns, ctx);
|
populateConvPatterns(patterns, ctx);
|
||||||
populateExperimentalPoolingTilingPattern(patterns, ctx);
|
populatePoolPatterns(patterns, ctx);
|
||||||
populateGemmToConvConversionPattern(patterns, ctx);
|
populateReluPatterns(patterns, ctx);
|
||||||
}
|
populateConcatPatterns(patterns, ctx);
|
||||||
else {
|
populateReshapePatterns(patterns, ctx);
|
||||||
populateTilingConvOpPattern(patterns, ctx);
|
|
||||||
populatePoolingTilingPattern(patterns, ctx);
|
|
||||||
populateTilingGemmOpPattern(patterns, ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
|
||||||
populateReduceMeanConversionPattern(patterns, ctx);
|
|
||||||
|
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
@@ -85,8 +101,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
// Count the number of compute ops and check they do not exceed the core count
|
// Count the number of compute ops and check they do not exceed the core count
|
||||||
if (coresCount != -1) {
|
if (coresCount != -1) {
|
||||||
int computeOpsCount = 0;
|
int computeOpsCount = 0;
|
||||||
for (auto& op : funcOp.getFunctionBody().front().getOperations())
|
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
||||||
if (isa<SpatWeightedCompute>(op))
|
if (isa<spatial::SpatWeightedCompute>(op))
|
||||||
computeOpsCount++;
|
computeOpsCount++;
|
||||||
|
|
||||||
if (computeOpsCount > coresCount) {
|
if (computeOpsCount > coresCount) {
|
||||||
@@ -96,29 +112,26 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
|
PassManager cleanupPM(ctx);
|
||||||
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
|
cleanupPM.addPass(createCanonicalizerPass());
|
||||||
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
|
if (failed(cleanupPM.run(moduleOp)))
|
||||||
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
|
annotateWeightsConstants(*entryFunc);
|
||||||
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
|
|
||||||
|
|
||||||
annotateWeightsConstants(funcOp);
|
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "spatial");
|
dumpModule(moduleOp, "spatial");
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
MLIRContext* ctx = funcOp.getContext();
|
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); });
|
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
|
||||||
if (isAlwaysWeight)
|
if (isAlwaysWeight)
|
||||||
constantOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
markWeightAlways(constantOp);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
extern bool haveSameStaticShape(Value lhs, Value rhs);
|
|
||||||
|
|
||||||
namespace spatial {
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
|
||||||
|
|
||||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
|
||||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
|
||||||
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
|
|
||||||
|
|
||||||
ONNXToSpatialPass() = default;
|
|
||||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
|
||||||
|
|
||||||
void runOnOperation() override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace spatial
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
// Experimental patterns.
|
|
||||||
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
22
src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
Normal file
22
src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
265
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
Normal file
265
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXConvOp convOp,
|
||||||
|
ONNXConvOpAdaptor convOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Lowers a static-shape, group=1, 2-D ONNXConvOp into an im2col + Gemm
// sequence:
//   1. A SpatWeightedCompute region builds the im2col matrix from X
//      (zero-padding first when the conv has padding).
//   2. An ONNXGemmOp multiplies it by the flattened, transposed weights
//      (so the filters sit in the crossbar's B matrix) and adds the bias.
//   3. A second SpatWeightedCompute region reshapes/transposes the Gemm
//      result back into NCHW layout.
// convOp is replaced with the result of the final compute region.
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
    ONNXConvOpAdaptor convOpAdaptor,
    ConversionPatternRewriter& rewriter) const {
  Location loc = convOp.getLoc();
  Value x = convOpAdaptor.getX();
  Value w = convOpAdaptor.getW();
  Value b = convOpAdaptor.getB();

  auto xType = cast<RankedTensorType>(x.getType());
  auto wType = cast<RankedTensorType>(w.getType());
  auto outType = cast<RankedTensorType>(convOp.getY().getType());

  assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
  assert("Only support 2D convolution" && xType.getRank() == 4);

  // We need to understand what is group
  assert("Only support group=1" && convOp.getGroup() == 1);

  // X is NCHW, W is [cOut, cIn, kH, kW], Y is NCHW.
  const int64_t batchSize = xType.getDimSize(0);
  const int64_t numChannelsIn = xType.getDimSize(1);
  const int64_t xHeight = xType.getDimSize(2);
  const int64_t xWidth = xType.getDimSize(3);
  const int64_t numChannelsOut = wType.getDimSize(0);
  const int64_t wHeight = wType.getDimSize(2);
  const int64_t wWidth = wType.getDimSize(3);
  const int64_t outHeight = outType.getDimSize(2);
  const int64_t outWidth = outType.getDimSize(3);

  // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
  auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };

  const auto stridesAttr = convOp.getStrides();
  const auto dilationsAttr = convOp.getDilations();
  const auto padsAttr = convOp.getPads();

  const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
  const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
  const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
  const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;

  int64_t padHeightBegin = 0;
  int64_t padHeightEnd = 0;
  int64_t padWidthBegin = 0;
  int64_t padWidthEnd = 0;

  if (padsAttr) {
    // Explicit pads win; ONNX order is [hBegin, wBegin, hEnd, wEnd].
    padHeightBegin = getI64(*padsAttr, 0);
    padWidthBegin = getI64(*padsAttr, 1);
    padHeightEnd = getI64(*padsAttr, 2);
    padWidthEnd = getI64(*padsAttr, 3);
  }
  else {
    // Compute padding from auto_pad attribute
    const auto autoPad = convOp.getAutoPad();
    if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
      // Effective kernel extent accounts for dilation.
      const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
      const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
      // Total padding needed so the output keeps the expected spatial size.
      const int64_t totalPadH =
          std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
      const int64_t totalPadW =
          std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);

      if (autoPad == "SAME_UPPER") {
        // Odd padding surplus goes at the end.
        padHeightBegin = totalPadH / 2;
        padHeightEnd = totalPadH - padHeightBegin;
        padWidthBegin = totalPadW / 2;
        padWidthEnd = totalPadW - padWidthBegin;
      }
      else { // SAME_LOWER
        // Odd padding surplus goes at the beginning.
        padHeightEnd = totalPadH / 2;
        padHeightBegin = totalPadH - padHeightEnd;
        padWidthEnd = totalPadW / 2;
        padWidthBegin = totalPadW - padWidthEnd;
      }
    }
    // "NOTSET" or "VALID" -> all pads stay 0
  }

  // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
  // A (im2col): [numPatches, patchSize] -- one row per output spatial position
  // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
  // Gemm output: [numPatches, cOut]
  const int64_t patchSize = numChannelsIn * wHeight * wWidth;
  const int64_t numPatchesPerBatch = outHeight * outWidth;
  const int64_t numPatches = batchSize * numPatchesPerBatch;

  auto elemType = xType.getElementType();
  auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
  auto rowType = RankedTensorType::get({1, patchSize}, elemType);
  auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
  auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
  auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
  auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());

  // Prepare weight matrix W for crossbar storage:
  // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
  Value wFlat = tensor::CollapseShapeOp::create(rewriter,
      loc,
      wFlatType,
      w,
      SmallVector<ReassociationIndices> {
        {0},
        {1, 2, 3}
      });
  Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));

  // Pass bias through directly; Gemm handles rank-1 C canonicalization.
  // NOTE(review): assumes the optional B operand is always materialized as an
  // ONNXNoneOp result; getDefiningOp() would return null for a block argument
  // -- confirm against the conversion driver.
  bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
  Value gemmC;
  if (hasB)
    gemmC = b;
  else
    gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

  // First compute region: builds the im2col matrix from X.
  auto im2colComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});

  auto* im2colBlock = new Block();
  im2colBlock->addArgument(x.getType(), loc);
  im2colComputeOp.getBody().push_back(im2colBlock);
  rewriter.setInsertionPointToStart(im2colBlock);

  Value paddedInput = im2colBlock->getArgument(0);

  // Pad input with zeros if needed:
  // [batchSize, numChannelsIn, xHeight, xWidth] -> [batchSize, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
  if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
    const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
    const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
    auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
    SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
        rewriter.getIndexAttr(0),
        rewriter.getIndexAttr(padHeightBegin),
        rewriter.getIndexAttr(padWidthBegin)};
    SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
        rewriter.getIndexAttr(0),
        rewriter.getIndexAttr(padHeightEnd),
        rewriter.getIndexAttr(padWidthEnd)};
    auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
    // tensor.pad requires a body (one index argument per dim) yielding the
    // padding value.
    auto* padBlock = new Block();
    for (int i = 0; i < 4; i++)
      padBlock->addArgument(rewriter.getIndexType(), loc);
    padOp.getRegion().push_back(padBlock);
    rewriter.setInsertionPointToStart(padBlock);
    // NOTE(review): getFloatAttr assumes a float element type -- confirm that
    // integer convolutions cannot reach this pattern.
    auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
    tensor::YieldOp::create(rewriter, loc, zero.getResult());
    rewriter.setInsertionPointAfter(padOp);
    paddedInput = padOp.getResult();
  }

  // Build im2col [numPatches, patchSize]:
  // For each batch/output position (n, oh, ow), extract the patch from x
  SmallVector<Value> im2colRows;
  im2colRows.reserve(numPatches);
  for (int64_t n = 0; n < batchSize; n++) {
    for (int64_t oh = 0; oh < outHeight; oh++) {
      for (int64_t ow = 0; ow < outWidth; ow++) {
        SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
            rewriter.getIndexAttr(0),
            rewriter.getIndexAttr(oh * strideHeight),
            rewriter.getIndexAttr(ow * strideWidth)};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
            rewriter.getIndexAttr(numChannelsIn),
            rewriter.getIndexAttr(wHeight),
            rewriter.getIndexAttr(wWidth)};
        // The spatial slice strides implement dilation within the patch.
        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
            rewriter.getIndexAttr(1),
            rewriter.getIndexAttr(dilationHeight),
            rewriter.getIndexAttr(dilationWidth)};
        auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
        Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);

        // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
        Value row = tensor::CollapseShapeOp::create(rewriter,
            loc,
            rowType,
            patch,
            SmallVector<ReassociationIndices> {
              {0},
              {1, 2, 3}
            });
        im2colRows.push_back(row);
      }
    }
  }

  // Concatenate all rows: [numPatches, patchSize]
  Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
  spatial::SpatYieldOp::create(rewriter, loc, im2col);

  rewriter.setInsertionPointAfter(im2colComputeOp);

  // Gemm: A @ B + C = im2col @ W^T + b
  // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
  auto gemmOp = ONNXGemmOp::create(rewriter,
      loc,
      gemmOutType,
      im2colComputeOp.getResult(0),
      wTrans,
      gemmC,
      rewriter.getF32FloatAttr(1.0f),
      rewriter.getF32FloatAttr(1.0f),
      rewriter.getBoolAttr(false),
      rewriter.getBoolAttr(false));
  Value gemmOut = gemmOp.getY();

  // Second compute region: folds the Gemm result back into NCHW.
  auto collectComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});

  auto* collectBlock = new Block();
  collectBlock->addArgument(gemmOut.getType(), loc);
  collectComputeOp.getBody().push_back(collectBlock);
  rewriter.setInsertionPointToStart(collectBlock);

  auto gemmOutArg = collectBlock->getArguments().front();

  // Restore to NCHW layout:
  // [numPatches, numChannelsOut]
  // -> [batchSize, outHeight, outWidth, numChannelsOut]
  // -> [batchSize, numChannelsOut, outHeight, outWidth]
  Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
      loc,
      nhwcType,
      gemmOutArg,
      SmallVector<ReassociationIndices> {
        {0, 1, 2},
        {3}
      });
  Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));

  spatial::SpatYieldOp::create(rewriter, loc, nchwOut);

  rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
  return success();
}
|
||||||
|
|
||||||
|
// Registers the Conv-lowering patterns on the given pattern set.
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ConvToGemm>(ctx);
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
359
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
Normal file
359
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Location.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Materializes a new float constant tensor equal to `value` scaled
// element-wise by `factor`. Returns `value` unchanged when factor == 1,
// and failure when `value` is not an arith.constant holding a dense float
// attribute or when the scaling overflows/produces an invalid result.
//
// Fixes:
//  - APFloat arithmetic requires both operands to share the same float
//    semantics. The previous code multiplied every element (which may be
//    f64/f16/bf16) by an IEEE-single `factor`, which asserts inside LLVM
//    for any non-f32 tensor. The scale is now converted into the tensor's
//    element semantics first.
//  - APFloat::multiply reports opInexact for most scale factors (any
//    non-exactly-representable product); that is expected rounding, not a
//    failure. Only genuine error statuses (overflow, invalid op, ...) now
//    cause this helper to bail out.
static FailureOr<Value>
materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  if (factor == 1.0f)
    return value;

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (!constantOp)
    return failure();

  auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
  if (!denseAttr)
    return failure();

  // Convert the scale factor into the element type's semantics up front so
  // the per-element multiply operates on matching APFloat semantics.
  auto elemType = cast<FloatType>(cast<RankedTensorType>(denseAttr.getType()).getElementType());
  APFloat scale(factor);
  bool losesInfo = false;
  scale.convert(elemType.getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  bool hadFailure = false;
  for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
    APFloat scaledValue(originalValue);
    // Rounding (opInexact) is fine; anything else is a real failure.
    if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven) & ~APFloat::opInexact)
      hadFailure = true;
    scaledValues.push_back(std::move(scaledValue));
  }
  if (hadFailure)
    return failure();

  auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
  return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
|
||||||
|
|
||||||
|
// Splits a multi-row GEMM (A has more than one row) into one single-row
// Gemm (GEMV) per output row; the rows are concatenated back together
// inside a spatial compute region. Implementation below.
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
      ONNXGemmOpAdaptor gemmOpAdaptor,
      ConversionPatternRewriter& rewriter) const override;
};
|
||||||
|
|
||||||
|
// Lowers a single-row GEMM (GEMV) to spatial compute ops: the B matrix is
// tiled into crossbar-sized weights, the input row is sliced per crossbar
// and per core, and partial VMM results are summed and concatenated back
// into the full output row. Implementation below.
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
      ONNXGemmOpAdaptor gemmOpAdaptor,
      ConversionPatternRewriter& rewriter) const override;
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Decomposes an M-row GEMM (M > 1) into M single-row Gemm ops, folding
// alpha into B and beta into C beforehand (both must be dense float
// constants, otherwise the pattern fails). The per-row results are
// concatenated along axis 0 inside a SpatWeightedCompute region, which
// replaces gemmOp.
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
    ONNXGemmOpAdaptor gemmOpAdaptor,
    ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());

  // NOTE(review): assumes the optional C operand is materialized as an
  // ONNXNoneOp result; getDefiningOp() would be null for a block argument.
  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());

  auto aType = cast<RankedTensorType>(a.getType());
  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
  assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());

  const int64_t numOutRows = aType.getDimSize(0);

  // Only decompose when there are multiple rows to split
  if (numOutRows <= 1)
    return failure();

  // Fold alpha into B so the per-row Gemms can use alpha = 1.
  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  bool cHasNumOutRows = false;
  if (hasC) {
    // Fold beta into C so the per-row Gemms can use beta = 1.
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter,
          loc,
          expandedType,
          c,
          SmallVector<ReassociationIndices> {
            {0, 1}
          });
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
    // If C has one row per output row, each GEMV gets its own slice of C;
    // otherwise the (broadcast) C is shared by all rows.
    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());

  SmallVector<Value> gemvOps;
  gemvOps.reserve(numOutRows);
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    // Slice row rowIdx of A: [1, K].
    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
    auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();

    Value cSlice = c;
    if (hasC) {
      if (cHasNumOutRows) {
        // Per-row bias: slice the matching row of C.
        SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
        cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
      }
      else
        assert("C should be a vector" && isVectorShape(getTensorShape(c)));
    }

    // alpha/beta were already folded into B/C above, so pass 1.0 here.
    auto gemvOp = ONNXGemmOp::create(rewriter,
        loc,
        outRowType,
        aSlice,
        b,
        cSlice,
        rewriter.getF32FloatAttr(1.0f),
        rewriter.getF32FloatAttr(1.0f),
        gemmOp.getTransAAttr(),
        gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
  }

  // Concatenate the per-row results inside a compute region.
  auto concatComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);

  auto* concatBlock = new Block();
  for (auto gemvOp : gemvOps)
    concatBlock->addArgument(gemvOp.getType(), loc);
  concatComputeOp.getBody().push_back(concatBlock);
  rewriter.setInsertionPointToStart(concatBlock);

  auto blockArgs = concatBlock->getArguments();
  auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
  spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
|
||||||
|
|
||||||
|
// Lowers a single-row GEMM (GEMV) onto the spatial fabric:
//  - normalizes transA/transB via explicit transposes,
//  - folds alpha into B and beta into C (must be dense float constants),
//  - tiles B into crossbar-sized weight tiles and slices the A row per
//    crossbar/core,
//  - for each output horizontal slice, issues per-core SpatWeightedVMMOps,
//    sums the per-core partials (plus the matching C slice), and
//  - concatenates the output slices back into the full row, replacing
//    gemmOp with the final SpatWeightedCompute.
//
// Fix: the inner VMM loop previously iterated up to aNumHSlices (the GLOBAL
// number of A slices) while indexing computeArgs, which only holds THIS
// core's slices (aHSlices[coreId].size()). Whenever the slices are spread
// over more than one core (coresPerVSlice > 1), that indexed past the end of
// the block-argument range. The loop is now bounded by computeArgs.size(),
// matching the vmmOutputs.reserve(computeArgs.size()) intent and the
// per-core weight indexing above.
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
    ONNXGemmOpAdaptor gemmOpAdaptor,
    ConversionPatternRewriter& rewriter) const {
  Location gemmLoc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  Value out = gemmOp.getY();

  float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  float beta = gemmOpAdaptor.getBeta().convertToFloat();
  bool transA = gemmOpAdaptor.getTransA();
  bool transB = gemmOpAdaptor.getTransB();

  auto aType = cast<RankedTensorType>(a.getType());
  auto bType = cast<RankedTensorType>(b.getType());
  auto outType = cast<RankedTensorType>(out.getType());

  RankedTensorType cType = nullptr;
  // NOTE(review): assumes the optional C operand is materialized as an
  // ONNXNoneOp result; getDefiningOp() would be null for a block argument.
  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
  if (hasC) {
    cType = cast<RankedTensorType>(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter,
          gemmLoc,
          expandedType,
          c,
          SmallVector<ReassociationIndices> {
            {0, 1}
          });
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
  }

  assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
      && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());

  if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
    // Not a gemv
    return failure();

  // Normalize away the transpose attributes with explicit transposes.
  if (transA) {
    auto aShape = aType.getShape();
    auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
    a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
  }
  if (transB) {
    auto bShape = bType.getShape();
    auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
    b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
    bType = cast<RankedTensorType>(b.getType());
  }

  // Fold alpha into B and beta into C so the lowered compute is a plain
  // a @ B + c.
  if (alpha != 1.0f) {
    auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
    if (failed(scaledB))
      return failure();
    b = *scaledB;
    bType = cast<RankedTensorType>(b.getType());
    alpha = 1.0f;
  }
  if (hasC && beta != 1.0f) {
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());
    beta = 1.0f;
  }

  // Slice counts along each operand's horizontal dimension; B's vertical
  // tiling follows A's horizontal tiling, and C/out follow B's horizontal
  // tiling (the aliases below document that correspondence).
  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
  auto bNumVSlices = aNumHSlices;
  auto bLastVSliceSize = aLastHSliceSize;
  auto cNumHSlices = bNumHSlices;
  auto cLastHSliceSize = bLastHSliceSize;
  auto outNumHSlices = cNumHSlices;
  auto outLastHSliceSize = cLastHSliceSize;

  // Number of cores needed to hold all of B's vertical slices.
  const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

  DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);

  DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
      tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);

  SmallVector<Value> cHSlices;
  // A scalar [1, 1] bias is broadcast to the full output width first.
  if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
    c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
  if (hasC)
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);

  RankedTensorType outHSliceType =
      RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
  RankedTensorType outLastHSliceType =
      RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());

  SmallVector<Value> outHSlices;
  outHSlices.reserve(outNumHSlices);
  for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
    // The final slice may be narrower than a full crossbar.
    RankedTensorType currOutHSliceType = outHSliceType;
    if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
      currOutHSliceType = outLastHSliceType;

    SmallVector<Value> partialResults;
    partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      // Gather this core's weight tiles for the current output slice.
      SmallVector<Value> weights;
      weights.reserve(aHSlices[coreId].size());

      for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);

      auto computeOp =
          spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);

      auto* computeBlock = new Block();
      for (auto aHSlice : aHSlices[coreId])
        computeBlock->addArgument(aHSlice.getType(), gemmLoc);
      computeOp.getBody().push_back(computeBlock);
      rewriter.setInsertionPointToStart(computeBlock);

      auto computeArgs = computeBlock->getArguments();
      SmallVector<Value> vmmOutputs;
      vmmOutputs.reserve(computeArgs.size());
      // One VMM per crossbar held by this core; indices are core-local
      // (same indexing as the `weights` gathered above).
      for (size_t aHSliceId = 0; aHSliceId < computeArgs.size(); aHSliceId++)
        vmmOutputs.push_back(
            spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
      assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");

      // Sum the per-crossbar partials inside the core's compute region.
      Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
      spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
      rewriter.setInsertionPointAfter(computeOp);

      partialResults.push_back(computeOp.getResult(0));
    }

    // The bias slice joins the cross-core reduction as one more addend.
    if (hasC) {
      Value cHSlice = cHSlices[outSliceId];
      partialResults.push_back(cHSlice);
    }

    auto reduceComputeOp =
        spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);

    auto* reduceBlock = new Block();
    for (auto partialResult : partialResults)
      reduceBlock->addArgument(partialResult.getType(), gemmLoc);
    reduceComputeOp.getBody().push_back(reduceBlock);
    rewriter.setInsertionPointToStart(reduceBlock);

    auto blockArgs = reduceBlock->getArguments();
    Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
    spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
    rewriter.setInsertionPointAfter(reduceComputeOp);

    outHSlices.push_back(reduceComputeOp.getResult(0));
  }

  // Concatenate the output slices back into the full row.
  auto concatComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);

  auto* concatBlock = new Block();
  for (auto outHSlice : outHSlices)
    concatBlock->addArgument(outHSlice.getType(), gemmLoc);
  concatComputeOp.getBody().push_back(concatBlock);
  rewriter.setInsertionPointToStart(concatBlock);

  auto blockArgs = concatBlock->getArguments();
  auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
  spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
|
||||||
|
|
||||||
|
// Registers the Gemm-lowering patterns on the given pattern set.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<GemmToManyGemv, GemvToSpatialCompute>(ctx);
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
119
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
Normal file
119
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Decomposes a rank-3 ONNX MatMul, A: [m, k] @ B: [batch, k, n] -> [batch, m, n],
// into batch*n single-row Gemm ops (one per output column), concatenates the
// rows inside a SpatWeightedCompute region, and reshapes/transposes the
// result back to the original output layout. Only fully static shapes with a
// rank-2 LHS and rank-3 RHS are matched.
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    // Only handle fully static rank-2 @ rank-3 matmuls.
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
        || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
      return failure();

    const int64_t batch = rhsType.getDimSize(0);
    const int64_t k = rhsType.getDimSize(1);
    const int64_t n = rhsType.getDimSize(2);
    const int64_t m = lhsType.getDimSize(0);
    // Shapes must agree: [m, k] @ [batch, k, n] -> [batch, m, n].
    if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
        || outType.getDimSize(2) != n)
      return failure();

    Location loc = matmulOp.getLoc();
    auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
    auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
    auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
    auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
    auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
    auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());

    // Each output column is computed as (B[b][:, col])^T @ A^T -> [1, m].
    Value lhsTransposed =
        ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    SmallVector<Value> gemmRows;
    gemmRows.reserve(batch * n);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      for (int64_t colIdx = 0; colIdx < n; colIdx++) {
        // Extract column colIdx of B[batchIdx]: [1, k, 1].
        SmallVector<OpFoldResult> offsets = {
            rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
        SmallVector<OpFoldResult> sizes = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
        SmallVector<OpFoldResult> strides = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        Value rhsSlice =
            tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
        // Flatten [1, k, 1] -> [1, k].
        Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
            loc,
            rhsRowType,
            rhsSlice,
            SmallVector<ReassociationIndices> {
              {0},
              {1, 2}
            });

        auto gemmOp = ONNXGemmOp::create(rewriter,
            loc,
            gemmRowType,
            rhsRow,
            lhsTransposed,
            none,
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getBoolAttr(false),
            rewriter.getBoolAttr(false));
        gemmRows.push_back(gemmOp.getY());
      }
    }

    // Concatenate all row results inside a compute region: [batch * n, m].
    auto concatComputeOp =
        spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);

    auto* concatBlock = new Block();
    for (Value gemmRow : gemmRows)
      concatBlock->addArgument(gemmRow.getType(), loc);
    concatComputeOp.getBody().push_back(concatBlock);
    rewriter.setInsertionPointToStart(concatBlock);

    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());

    rewriter.setInsertionPointAfter(concatComputeOp);
    Value gemmOut = concatComputeOp.getResult(0);
    // Restore the MatMul output layout:
    // [batch * n, m] -> [batch, n, m] -> [batch, m, n]
    Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
        loc,
        gemmExpandedType,
        gemmOut,
        SmallVector<ReassociationIndices> {
          {0, 1},
          {2}
        });
    Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));

    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<MatMulRank3ToGemm>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <optional>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename ArrayAttrT>
|
||||||
|
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||||
|
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ArrayAttrT>
|
||||||
|
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||||
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||||
|
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||||
|
if (values.size() == 1)
|
||||||
|
return values.front();
|
||||||
|
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
|
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||||
|
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.reserve(tileType.getRank());
|
||||||
|
for (int64_t dimSize : tileType.getShape())
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||||
|
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
|
||||||
|
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ReduceOp>
|
||||||
|
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||||
|
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||||
|
|
||||||
|
Value reduced = windowValues.front();
|
||||||
|
for (Value value : windowValues.drop_front())
|
||||||
|
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
||||||
|
return reduced;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||||
|
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||||
|
if (divisor == 1)
|
||||||
|
return reducedWindow;
|
||||||
|
|
||||||
|
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
||||||
|
double scale = 1.0 / static_cast<double>(divisor);
|
||||||
|
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||||
|
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||||
|
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PoolOp>
|
||||||
|
struct PoolToSpatialCompute;
|
||||||
|
|
||||||
|
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||||
|
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||||
|
using OpConversionPattern<PoolOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Location loc = poolOp.getLoc();
|
||||||
|
Value x = adaptor.getX();
|
||||||
|
|
||||||
|
auto xType = dyn_cast<RankedTensorType>(x.getType());
|
||||||
|
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
|
||||||
|
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
|
||||||
|
if (xType.getRank() != 4 || outType.getRank() != 4)
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
|
||||||
|
|
||||||
|
ArrayAttr kernelAttr = poolOp.getKernelShape();
|
||||||
|
if (!kernelAttr || kernelAttr.size() != 2)
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
|
const int64_t channels = xType.getDimSize(1);
|
||||||
|
const int64_t inputHeight = xType.getDimSize(2);
|
||||||
|
const int64_t inputWidth = xType.getDimSize(3);
|
||||||
|
const int64_t outputHeight = outType.getDimSize(2);
|
||||||
|
const int64_t outputWidth = outType.getDimSize(3);
|
||||||
|
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||||
|
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||||
|
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||||
|
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||||
|
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||||
|
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||||
|
|
||||||
|
int64_t padTop = 0;
|
||||||
|
int64_t padLeft = 0;
|
||||||
|
int64_t padBottom = 0;
|
||||||
|
int64_t padRight = 0;
|
||||||
|
|
||||||
|
if (auto padsAttr = poolOp.getPads()) {
|
||||||
|
if (padsAttr->size() != 4)
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||||
|
padTop = getI64(*padsAttr, 0);
|
||||||
|
padLeft = getI64(*padsAttr, 1);
|
||||||
|
padBottom = getI64(*padsAttr, 2);
|
||||||
|
padRight = getI64(*padsAttr, 3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
StringRef autoPad = poolOp.getAutoPad();
|
||||||
|
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
|
||||||
|
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
|
||||||
|
const int64_t totalPadH =
|
||||||
|
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
|
||||||
|
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
|
||||||
|
|
||||||
|
if (autoPad == "SAME_UPPER") {
|
||||||
|
padTop = totalPadH / 2;
|
||||||
|
padBottom = totalPadH - padTop;
|
||||||
|
padLeft = totalPadW / 2;
|
||||||
|
padRight = totalPadW - padLeft;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
padBottom = totalPadH / 2;
|
||||||
|
padTop = totalPadH - padBottom;
|
||||||
|
padRight = totalPadW / 2;
|
||||||
|
padLeft = totalPadW - padRight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(void) padBottom;
|
||||||
|
(void) padRight;
|
||||||
|
|
||||||
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||||
|
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
|
||||||
|
|
||||||
|
auto* computeBlock = new Block();
|
||||||
|
computeBlock->addArgument(xType, loc);
|
||||||
|
computeOp.getBody().push_back(computeBlock);
|
||||||
|
rewriter.setInsertionPointToStart(computeBlock);
|
||||||
|
|
||||||
|
Value input = computeBlock->getArgument(0);
|
||||||
|
SmallVector<Value> batchResults;
|
||||||
|
batchResults.reserve(batchSize);
|
||||||
|
|
||||||
|
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||||
|
SmallVector<Value> rows;
|
||||||
|
rows.reserve(outputHeight);
|
||||||
|
|
||||||
|
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||||
|
SmallVector<Value> rowPixels;
|
||||||
|
rowPixels.reserve(outputWidth);
|
||||||
|
|
||||||
|
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||||
|
SmallVector<Value> outputChannelTiles;
|
||||||
|
outputChannelTiles.reserve(channelTileCount);
|
||||||
|
|
||||||
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||||
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||||
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||||
|
|
||||||
|
SmallVector<Value> windowValues;
|
||||||
|
windowValues.reserve(kernelHeight * kernelWidth);
|
||||||
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
|
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||||
|
if (inH < 0 || inH >= inputHeight)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
|
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||||
|
if (inW < 0 || inW >= inputWidth)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||||
|
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||||
|
rewriter.getIndexAttr(inH),
|
||||||
|
rewriter.getIndexAttr(inW)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(tileChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
Value windowValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
|
||||||
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||||
|
windowValues.push_back(windowValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (windowValues.empty())
|
||||||
|
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||||
|
|
||||||
|
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||||
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
|
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||||
|
const int64_t divisor =
|
||||||
|
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||||
|
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||||
|
}
|
||||||
|
|
||||||
|
outputChannelTiles.push_back(reducedWindow);
|
||||||
|
}
|
||||||
|
|
||||||
|
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||||
|
}
|
||||||
|
|
||||||
|
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||||
|
|
||||||
|
rewriter.replaceOp(poolOp, computeOp.getResult(0));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
|
||||||
|
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
|
||||||
|
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
||||||
|
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
|
||||||
|
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||||
|
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
33
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
Normal file
33
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||||
|
Location loc = reluOp.getLoc();
|
||||||
|
Type resultType = reluOp.getResult().getType();
|
||||||
|
constexpr size_t numInputs = 1;
|
||||||
|
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
|
||||||
|
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(reluOp, computeOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,16 +1,15 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||||
ONNXConcatToTensorConcat(MLIRContext* ctx)
|
using OpConversionPattern::OpConversionPattern;
|
||||||
: OpConversionPattern(ctx) {}
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
||||||
ONNXConcatOpAdaptor adaptor,
|
ONNXConcatOpAdaptor adaptor,
|
||||||
@@ -24,8 +23,8 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<ONNXConcatToTensorConcat>(ctx);
|
patterns.insert<Concat>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
119
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp
Normal file
119
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
||||||
|
ArrayRef<int64_t> resultShape,
|
||||||
|
SmallVector<ReassociationIndices>& reassociation) {
|
||||||
|
reassociation.clear();
|
||||||
|
|
||||||
|
size_t sourceIdx = 0;
|
||||||
|
size_t resultIdx = 0;
|
||||||
|
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||||
|
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||||
|
int64_t resultProduct = resultShape[resultIdx];
|
||||||
|
|
||||||
|
ReassociationIndices group;
|
||||||
|
group.push_back(sourceIdx);
|
||||||
|
while (sourceProduct != resultProduct) {
|
||||||
|
if (sourceProduct > resultProduct)
|
||||||
|
return false;
|
||||||
|
sourceIdx++;
|
||||||
|
if (sourceIdx >= sourceShape.size())
|
||||||
|
return false;
|
||||||
|
group.push_back(sourceIdx);
|
||||||
|
sourceProduct *= sourceShape[sourceIdx];
|
||||||
|
}
|
||||||
|
|
||||||
|
reassociation.push_back(group);
|
||||||
|
sourceIdx++;
|
||||||
|
resultIdx++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||||
|
ArrayRef<int64_t> resultShape,
|
||||||
|
SmallVector<ReassociationIndices>& reassociation) {
|
||||||
|
reassociation.clear();
|
||||||
|
|
||||||
|
size_t sourceIdx = 0;
|
||||||
|
size_t resultIdx = 0;
|
||||||
|
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||||
|
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||||
|
int64_t resultProduct = resultShape[resultIdx];
|
||||||
|
|
||||||
|
ReassociationIndices group;
|
||||||
|
group.push_back(resultIdx);
|
||||||
|
while (resultProduct != sourceProduct) {
|
||||||
|
if (resultProduct > sourceProduct)
|
||||||
|
return false;
|
||||||
|
resultIdx++;
|
||||||
|
if (resultIdx >= resultShape.size())
|
||||||
|
return false;
|
||||||
|
group.push_back(resultIdx);
|
||||||
|
resultProduct *= resultShape[resultIdx];
|
||||||
|
}
|
||||||
|
|
||||||
|
reassociation.push_back(group);
|
||||||
|
sourceIdx++;
|
||||||
|
resultIdx++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
|
||||||
|
ONNXReshapeOpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (sourceType == resultType) {
|
||||||
|
rewriter.replaceOp(reshapeOp, adaptor.getData());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation;
|
||||||
|
if (sourceType.getRank() > resultType.getRank()
|
||||||
|
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sourceType.getRank() < resultType.getRank()
|
||||||
|
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
template <typename OpTy, typename OpAdaptorTy>
|
|
||||||
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> {
|
|
||||||
RemoveUnusedHelperOps(MLIRContext* ctx)
|
|
||||||
: OpRewritePattern<OpTy>(ctx) {}
|
|
||||||
|
|
||||||
void initialize() { this->setHasBoundedRewriteRecursion(); }
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
|
|
||||||
if (op.getResult().use_empty()) {
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
|
|
||||||
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
|
|
||||||
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
#include <queue>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Structure that describes the replication of a convolution operation,
|
|
||||||
* along the image height axis.
|
|
||||||
*/
|
|
||||||
struct ConvReplication {
|
|
||||||
ONNXConvOp convOp; // Convolution operation
|
|
||||||
size_t input_w; // Width of the input image
|
|
||||||
size_t replicationFactor; // Replication factor on the image height axis
|
|
||||||
size_t coresNeededPerReplica; // Number of cores needed for each replica
|
|
||||||
|
|
||||||
friend bool operator<(const ConvReplication& a, const ConvReplication& b) {
|
|
||||||
return a.input_w / a.replicationFactor < b.input_w / b.replicationFactor;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConvReplication(ONNXConvOp convOp, size_t input_w, size_t replicationFactor, size_t coresNeededPerReplica)
|
|
||||||
: convOp(convOp),
|
|
||||||
input_w(input_w),
|
|
||||||
replicationFactor(replicationFactor),
|
|
||||||
coresNeededPerReplica(coresNeededPerReplica) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter) {
|
|
||||||
|
|
||||||
if (coresCount == -1) {
|
|
||||||
// No need for annotation, implicitly set replication to 1
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::priority_queue<struct ConvReplication> convOpsReplicationQueue;
|
|
||||||
|
|
||||||
size_t minimumCores = 0;
|
|
||||||
|
|
||||||
for (auto& op : funcOp.getFunctionBody().begin()->getOperations()) {
|
|
||||||
if (auto convOp = dyn_cast<ONNXConvOp>(op)) {
|
|
||||||
// Convolution layer
|
|
||||||
|
|
||||||
Value X = convOp.getX(), W = convOp.getW();
|
|
||||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
|
||||||
ShapedType wShape = mlir::cast<ShapedType>(W.getType());
|
|
||||||
|
|
||||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
|
||||||
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
|
|
||||||
size_t krn_w = GET_KERNEL_WIDTH(wShape);
|
|
||||||
|
|
||||||
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
|
||||||
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
|
|
||||||
|
|
||||||
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
|
|
||||||
auto neededCores = ceilIntegerDivide(neededXbars, crossbarCountInCore.getValue());
|
|
||||||
|
|
||||||
minimumCores += neededCores;
|
|
||||||
|
|
||||||
convOpsReplicationQueue.emplace(convOp, input_w, 1, neededCores);
|
|
||||||
}
|
|
||||||
else if (auto gemmOp = dyn_cast<ONNXGemmOp>(op)) {
|
|
||||||
// Fully connected layer
|
|
||||||
auto matrixTensorShape = cast<ShapedType>(gemmOp.getB().getType());
|
|
||||||
auto inputSize = matrixTensorShape.getDimSize(0);
|
|
||||||
auto outputSize = matrixTensorShape.getDimSize(1);
|
|
||||||
if (gemmOp.getTransB())
|
|
||||||
std::swap(inputSize, outputSize);
|
|
||||||
|
|
||||||
const size_t inputTilesCount = ceilIntegerDivide(inputSize, crossbarSize.getValue());
|
|
||||||
const size_t outputTilesCount = ceilIntegerDivide(outputSize, crossbarSize.getValue());
|
|
||||||
|
|
||||||
// Each output tile is computed by `coresPerOutputTile` cores. The
|
|
||||||
// entire input is given to each of these cores.
|
|
||||||
const size_t coresPerOutputTile = ceilIntegerDivide(inputTilesCount, crossbarCountInCore.getValue());
|
|
||||||
|
|
||||||
auto neededCores = coresPerOutputTile * outputTilesCount;
|
|
||||||
|
|
||||||
minimumCores += neededCores;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (static_cast<size_t>(coresCount) < minimumCores) {
|
|
||||||
return funcOp->emitError("Not enough cores for this network: ")
|
|
||||||
<< minimumCores << " cores needed, but only " << static_cast<size_t>(coresCount) << " available.";
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t availableCores = static_cast<size_t>(coresCount) - minimumCores;
|
|
||||||
|
|
||||||
// Consume all the elements in the queue
|
|
||||||
while (!convOpsReplicationQueue.empty()) {
|
|
||||||
auto convOpReplication = convOpsReplicationQueue.top();
|
|
||||||
convOpsReplicationQueue.pop();
|
|
||||||
|
|
||||||
// Check if we can replicate this convolution (e.g. we have enough cores)
|
|
||||||
if (availableCores > convOpReplication.coresNeededPerReplica * (convOpReplication.replicationFactor + 1)) {
|
|
||||||
// We can replicate this convolution: increment replicationFactor and put
|
|
||||||
// back in queue
|
|
||||||
availableCores -= convOpReplication.coresNeededPerReplica;
|
|
||||||
convOpReplication.replicationFactor++;
|
|
||||||
|
|
||||||
convOpsReplicationQueue.push(convOpReplication);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Cannot replicate this convolution anymore, annotate the operation
|
|
||||||
// with the replication factor
|
|
||||||
convOpReplication.convOp->setAttr(REPLICATION_ATTR_NAME,
|
|
||||||
rewriter.getI64IntegerAttr(convOpReplication.replicationFactor));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
mlir::LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,346 +0,0 @@
|
|||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "SpatialReducer.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
|
|
||||||
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
|
|
||||||
|
|
||||||
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
|
||||||
std::function<Value(const Value&)> processFun,
|
|
||||||
ConversionPatternRewriter& rewriter) {
|
|
||||||
assert(processFun);
|
|
||||||
|
|
||||||
auto computeOp = GET_COMP(computeOpAndResNum);
|
|
||||||
auto resultNum = GET_RES_NUM(computeOpAndResNum);
|
|
||||||
|
|
||||||
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
|
||||||
|
|
||||||
Value result = yieldOp->getOperand(resultNum);
|
|
||||||
rewriter.setInsertionPointAfterValue(result);
|
|
||||||
Value processedResult = processFun(result);
|
|
||||||
if (processedResult == result) {
|
|
||||||
// Sometimes we want processedResult to return the same value but do
|
|
||||||
// something else with it (e.g. in softmax we want to broadcast the value
|
|
||||||
// using a channel). In this case, we can just return the same value.
|
|
||||||
return resultNum;
|
|
||||||
}
|
|
||||||
|
|
||||||
yieldOp->insertOperands(yieldOp->getNumOperands(), processedResult);
|
|
||||||
|
|
||||||
return yieldOp.getNumOperands() - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
|
||||||
std::function<Value(const Value&, const Value&)> reduce,
|
|
||||||
std::function<Value(const Value&)> preprocess,
|
|
||||||
std::function<Value(const Value&)> postprocess) {
|
|
||||||
|
|
||||||
if (preprocess)
|
|
||||||
for (auto& computeOpAndResNum : computeOpsAndResNum)
|
|
||||||
GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
|
|
||||||
|
|
||||||
// It is possible that `computeOpsAndResNum` contains two entries for the same
|
|
||||||
// computeOp. In this case, we need to apply the reduction within-computef
|
|
||||||
|
|
||||||
// Keep a map between a computeOp and the last Value for this reduction
|
|
||||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
|
||||||
for (auto& computeOpAndResNum : computeOpsAndResNum) {
|
|
||||||
auto computeOp = GET_COMP(computeOpAndResNum);
|
|
||||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
|
||||||
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
|
|
||||||
|
|
||||||
auto it = lastValueForCompute.find(computeOp.getOperation());
|
|
||||||
|
|
||||||
if (it != lastValueForCompute.end()) {
|
|
||||||
// If we have already seen this computeOp, apply the reduction
|
|
||||||
// within-compute
|
|
||||||
Value lastWithinComputeValue = it->second;
|
|
||||||
|
|
||||||
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
|
|
||||||
|
|
||||||
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
|
||||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
|
||||||
else
|
|
||||||
rewriter.setInsertionPointAfterValue(valueWithinCompute);
|
|
||||||
valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
|
|
||||||
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
|
|
||||||
}
|
|
||||||
|
|
||||||
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, reconstruct from the map the computeOpsAndResNum list
|
|
||||||
computeOpsAndResNum.clear();
|
|
||||||
computeOpsAndResNum.reserve(lastValueForCompute.size());
|
|
||||||
for (auto& entry : lastValueForCompute) {
|
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
|
|
||||||
auto valueWithinCompute = entry.second;
|
|
||||||
|
|
||||||
// We check if `valueWithinCompute` is already used by the yieldOp, in that
|
|
||||||
// case no need to add it
|
|
||||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
|
||||||
bool yieldOpUseFound = false;
|
|
||||||
for (auto& use : valueWithinCompute.getUses()) {
|
|
||||||
if (use.getOwner() == yieldOp.getOperation()) {
|
|
||||||
// If the value is already used by the yieldOp, we can just use it
|
|
||||||
computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
|
|
||||||
yieldOpUseFound = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (yieldOpUseFound)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// If this result is not used within a yieldOp, then add it
|
|
||||||
auto resultNum = yieldOp->getNumOperands();
|
|
||||||
yieldOp->insertOperands(resultNum, valueWithinCompute);
|
|
||||||
|
|
||||||
computeOpsAndResNum.push_back({computeOp, resultNum});
|
|
||||||
}
|
|
||||||
|
|
||||||
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
|
|
||||||
|
|
||||||
// Recursive algorithm to reduce the inputs to a single one:
|
|
||||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
|
||||||
// the computeOpsAndResNum list which becomes half the size.
|
|
||||||
// - Repeat until there is only one input left.
|
|
||||||
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
|
|
||||||
while (computeOpsRef.size() > 1) {
|
|
||||||
SmallVector<ComputeAndResNum> nextComputeOps;
|
|
||||||
nextComputeOps.reserve(computeOpsRef.size() / 2);
|
|
||||||
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
|
|
||||||
auto [firstCompute, firstResultNum] = computeOpsRef[i];
|
|
||||||
auto [secondCompute, secondResultNum] = computeOpsRef[i + 1];
|
|
||||||
|
|
||||||
if (secondCompute->isBeforeInBlock(firstCompute)) {
|
|
||||||
std::swap(firstCompute, secondCompute);
|
|
||||||
std::swap(firstResultNum, secondResultNum);
|
|
||||||
}
|
|
||||||
|
|
||||||
// We do not immediately alter the computeOps results/operands, instead we
|
|
||||||
// do it in a delayed manner, to avoid invalidating the references to the
|
|
||||||
// computeOps (which must be replaced by a cloned ComputeOp when changing
|
|
||||||
// the number of results)
|
|
||||||
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
|
|
||||||
|
|
||||||
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
|
|
||||||
|
|
||||||
// Add a new operand to the block of the second computeOp
|
|
||||||
Block& secondBlock = secondCompute.getBody().front();
|
|
||||||
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
|
|
||||||
|
|
||||||
auto secondComputeWeightsNum =
|
|
||||||
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
|
|
||||||
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
|
|
||||||
|
|
||||||
// Take the "former-result" from the second computeOp
|
|
||||||
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
|
|
||||||
Value formerRes2 = secondYield.getOperand(secondResultNum);
|
|
||||||
|
|
||||||
// Apply reduction operation
|
|
||||||
rewriter.setInsertionPoint(secondYield);
|
|
||||||
Value reduced = reduce(formerRes2, formerRes1);
|
|
||||||
|
|
||||||
// Unfortunately, it is not possible to update the result in place,
|
|
||||||
// because we may have already referenced it by <computeOp, resultNum>
|
|
||||||
// outside of this function, thus replacing it would invalidate the
|
|
||||||
// reference. Therefore, we need to append a new result to the yieldOp,
|
|
||||||
// and then at a later stage update the computeOp accordingly.
|
|
||||||
|
|
||||||
// Add `reduced` to the second yieldOp
|
|
||||||
auto secondYieldOperandNum = secondYield.getNumOperands();
|
|
||||||
secondYield->insertOperands(secondYieldOperandNum, reduced);
|
|
||||||
secondResultNum = secondYieldOperandNum;
|
|
||||||
|
|
||||||
// We should also add an entry for updating the results of the last
|
|
||||||
// operation (the one which never becomes a `firstCompute`): because it is
|
|
||||||
// not tracked by reducerChanges as `fromOp`
|
|
||||||
reducerChanges.push_back(
|
|
||||||
{firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum});
|
|
||||||
nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we have an odd number of inputs, we need to add the last one to the
|
|
||||||
// newInputs list.
|
|
||||||
if (computeOpsRef.size() % 2 == 1)
|
|
||||||
nextComputeOps.push_back(computeOpsRef.back());
|
|
||||||
|
|
||||||
// Replace the inputOps list with the new one.
|
|
||||||
computeOpsRef = llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point.");
|
|
||||||
|
|
||||||
auto finalComputeAndResNum = computeOpsRef[0];
|
|
||||||
|
|
||||||
// Force the update of the results of this computeOp, when finalizing
|
|
||||||
computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));
|
|
||||||
|
|
||||||
if (postprocess)
|
|
||||||
GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
|
|
||||||
|
|
||||||
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum));
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatialReducer::finalizeReduceUpdates() {
|
|
||||||
assert(reducesFinalized == false && "Cannot finalize two times.");
|
|
||||||
|
|
||||||
reducesFinalized = true;
|
|
||||||
|
|
||||||
// First, add the results to the computeOps
|
|
||||||
for (auto& reduceChange : reducerChanges)
|
|
||||||
updateResultsOfCompute(reduceChange.fromOp);
|
|
||||||
|
|
||||||
for (auto& c : computeOpNeedingResUpdate)
|
|
||||||
updateResultsOfCompute(c.getOperation());
|
|
||||||
|
|
||||||
for (auto& reducerChange : this->reducerChanges) {
|
|
||||||
auto fromOp = reducerChange.fromOp;
|
|
||||||
auto toOp = reducerChange.toOp;
|
|
||||||
auto fromOpResNum = reducerChange.fromOpResNum;
|
|
||||||
auto toOpOperandNum = reducerChange.toOpOperandNum;
|
|
||||||
|
|
||||||
auto fromComputeOp = opToReplacedCompute[fromOp];
|
|
||||||
assert(fromComputeOp && "fromOp should have been mapped before!");
|
|
||||||
|
|
||||||
// toComputeOp could be the existing pointer, or we have to remap it with
|
|
||||||
// `opToReplacedCompute`
|
|
||||||
auto toComputeOp = opToReplacedCompute[toOp];
|
|
||||||
if (!toComputeOp)
|
|
||||||
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
|
|
||||||
|
|
||||||
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
|
|
||||||
|
|
||||||
assert(toComputeOp->getNumOperands() == toOpOperandNum
|
|
||||||
&& "toOpOperandNum should be the last operand of toComputeOp, are the "
|
|
||||||
"operations in the right order?");
|
|
||||||
|
|
||||||
// Add the new operand to `toComputeOp`
|
|
||||||
auto fromResult = fromComputeOp.getResult(fromOpResNum);
|
|
||||||
toComputeOp->insertOperands(toOpOperandNum, fromResult);
|
|
||||||
incrementWeightedComputeInputsSegmentSize(toComputeOp, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
|
|
||||||
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
|
|
||||||
|
|
||||||
Operation* opToCast;
|
|
||||||
auto it = opToReplacedCompute.find(opAndResNum.first);
|
|
||||||
if (it != opToReplacedCompute.end())
|
|
||||||
opToCast = it->second;
|
|
||||||
else
|
|
||||||
opToCast = opAndResNum.first;
|
|
||||||
|
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
|
|
||||||
|
|
||||||
return computeOp.getResult(opAndResNum.second);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
|
|
||||||
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
|
|
||||||
// If we have already replaced the fromOp, we do not need to do it again
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp);
|
|
||||||
|
|
||||||
auto oldComputeOpNum = oldComputeOp->getNumOperands();
|
|
||||||
|
|
||||||
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
|
|
||||||
|
|
||||||
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
|
|
||||||
// No result was added, just add itself to the map
|
|
||||||
opToReplacedCompute[oldComputeOp.getOperation()] = oldComputeOp;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the results by inspecting its YieldOp
|
|
||||||
auto newResultTypes = yieldOp.getOperandTypes();
|
|
||||||
|
|
||||||
// Create a new ComputeOp with the new result type, but same operands
|
|
||||||
rewriter.setInsertionPoint(oldComputeOp);
|
|
||||||
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
|
||||||
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
|
||||||
|
|
||||||
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
|
|
||||||
|
|
||||||
auto newComputeOpNum = newComputeOp->getNumOperands();
|
|
||||||
|
|
||||||
assert(oldComputeOpNum == newComputeOpNum);
|
|
||||||
|
|
||||||
// Since we replaced the old ComputeOp with a new one, we need to replace
|
|
||||||
// all its results' uses
|
|
||||||
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
|
|
||||||
Value oldResult = oldComputeOp.getResult(i);
|
|
||||||
Value newResult = newComputeOp.getResult(i);
|
|
||||||
|
|
||||||
// Replace the uses, except the uses of the compute ops which got deleted
|
|
||||||
// previously
|
|
||||||
rewriter.replaceAllUsesExcept(oldResult, newResult, oldComputeOpsReplaced);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, erase the old computeOp and update the map
|
|
||||||
opToReplacedCompute[oldComputeOp.getOperation()] = newComputeOp;
|
|
||||||
oldComputeOpsReplaced.insert(oldComputeOp.getOperation());
|
|
||||||
rewriter.setInsertionPoint(oldComputeOp);
|
|
||||||
rewriter.eraseOp(oldComputeOp);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
|
|
||||||
Location& loc,
|
|
||||||
Type outputType) {
|
|
||||||
|
|
||||||
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
|
|
||||||
|
|
||||||
// outputTiles are indexed like this: [channelTile][x][y]
|
|
||||||
auto tilesCount = outputTiles.size();
|
|
||||||
auto width = outputTiles[0].size();
|
|
||||||
auto height = outputTiles[0][0].size();
|
|
||||||
|
|
||||||
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
|
|
||||||
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
|
|
||||||
|
|
||||||
for (size_t t = 0; t < tilesCount; t++)
|
|
||||||
for (size_t x = 0; x < width; x++)
|
|
||||||
for (size_t y = 0; y < height; y++)
|
|
||||||
remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]);
|
|
||||||
|
|
||||||
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
|
|
||||||
}
|
|
||||||
|
|
||||||
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Value biasTile,
|
|
||||||
MapOperations mapOp) {
|
|
||||||
|
|
||||||
std::function<Value(const Value&)> postprocessing = nullptr;
|
|
||||||
|
|
||||||
if (mapOp != MapOperations::None) {
|
|
||||||
postprocessing = [&](const Value a) {
|
|
||||||
Value mapOperand = a;
|
|
||||||
if (biasTile)
|
|
||||||
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
|
|
||||||
return createMapOperation(rewriter, mapOp, mapOperand);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return this->applyReducePattern(
|
|
||||||
computeOps,
|
|
||||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
|
|
||||||
/* preprocess = */ nullptr,
|
|
||||||
postprocessing);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
using ResNum = unsigned int;
|
|
||||||
|
|
||||||
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
|
|
||||||
|
|
||||||
struct SpatialReducerChange {
|
|
||||||
Operation* fromOp;
|
|
||||||
unsigned int fromOpResNum;
|
|
||||||
Operation* toOp;
|
|
||||||
unsigned int toOpOperandNum;
|
|
||||||
};
|
|
||||||
|
|
||||||
using OpAndResNum = std::pair<Operation*, ResNum>;
|
|
||||||
|
|
||||||
class SpatialReducer {
|
|
||||||
|
|
||||||
public:
|
|
||||||
SpatialReducer(ConversionPatternRewriter& rewriter)
|
|
||||||
: rewriter(rewriter) {}
|
|
||||||
|
|
||||||
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
|
||||||
std::function<Value(const Value&, const Value&)> reduce,
|
|
||||||
std::function<Value(const Value&)> preprocess,
|
|
||||||
std::function<Value(const Value&)> postprocess);
|
|
||||||
|
|
||||||
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Value biasTile,
|
|
||||||
MapOperations mapOp);
|
|
||||||
|
|
||||||
void finalizeReduceUpdates();
|
|
||||||
|
|
||||||
~SpatialReducer() {
|
|
||||||
if (!reducesFinalized)
|
|
||||||
finalizeReduceUpdates();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
|
|
||||||
Location& loc,
|
|
||||||
Type outputType);
|
|
||||||
|
|
||||||
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
|
|
||||||
|
|
||||||
private:
|
|
||||||
[[nodiscard("computeOp result number gets updated")]] ResNum
|
|
||||||
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
|
||||||
std::function<Value(const Value&)> processFun,
|
|
||||||
ConversionPatternRewriter& rewriter);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Update the results of a ComputeOp.
|
|
||||||
*
|
|
||||||
* This function updates the results of a ComputeOp by taking a look at the
|
|
||||||
operands of its yieldOp.
|
|
||||||
* If the ComputeOp was replaced, it updates `opToReplacedCompute` with the
|
|
||||||
replaced ComputeOp.
|
|
||||||
*
|
|
||||||
* @param computeOp The ComputeOp to update the results of.
|
|
||||||
*/
|
|
||||||
void updateResultsOfCompute(Operation* computeOp);
|
|
||||||
|
|
||||||
ConversionPatternRewriter& rewriter;
|
|
||||||
bool reducesFinalized = false;
|
|
||||||
|
|
||||||
// List of changes to be applied after the reduction is finalized
|
|
||||||
SmallVector<SpatialReducerChange, 4> reducerChanges;
|
|
||||||
// List of computeOps that need to be replaced with new results
|
|
||||||
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
|
|
||||||
|
|
||||||
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
|
|
||||||
|
|
||||||
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
|
|
||||||
: weights(std::move(weights)) {}
|
|
||||||
|
|
||||||
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
|
|
||||||
|
|
||||||
TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
|
||||||
assert(!weights.empty() && "No weights to extract.");
|
|
||||||
|
|
||||||
auto it = weights.begin();
|
|
||||||
SmallVector<Value>& values = it->second.begin()->second;
|
|
||||||
|
|
||||||
long inputTile = it->first;
|
|
||||||
long outputTile = it->second.begin()->first;
|
|
||||||
|
|
||||||
size_t n = std::min(amount, values.size());
|
|
||||||
crossbarsUsed += n;
|
|
||||||
|
|
||||||
SmallVector<Value> result;
|
|
||||||
result.assign(values.begin(), values.begin() + n);
|
|
||||||
|
|
||||||
if (n < values.size()) {
|
|
||||||
values.erase(values.begin(), values.begin() + n);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
it->second.erase(outputTile);
|
|
||||||
if (it->second.empty())
|
|
||||||
weights.erase(inputTile);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {inputTile, outputTile, crossbarsUsed - n, result};
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
|
|
||||||
crossbarsUsed = 0;
|
|
||||||
SmallVector<TaggedWeights> result;
|
|
||||||
size_t remaining = n;
|
|
||||||
|
|
||||||
while (remaining > 0 && !weights.empty()) {
|
|
||||||
auto group = popGroup(remaining);
|
|
||||||
result.push_back(group);
|
|
||||||
remaining -= group.weights.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A helper struct to store a group of weights.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
struct TaggedWeights {
|
|
||||||
long inputTile;
|
|
||||||
long outputTile;
|
|
||||||
size_t startingCrossbarIndex;
|
|
||||||
SmallVector<Value> weights;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief A helper class to subdivide weights into groups.
|
|
||||||
*
|
|
||||||
* Weights are stored as a map of maps of SmallVectors. The outer map is indexed
|
|
||||||
* by input tile, the inner map is indexed by output tile, and the SmallVector
|
|
||||||
* contains the weights for the filter. This class allows us to extract groups
|
|
||||||
* of weights from the map until we've extracted a certain number of elements,
|
|
||||||
* namely as many as we need to fill a compute unit.
|
|
||||||
*/
|
|
||||||
class WeightSubdivider {
|
|
||||||
private:
|
|
||||||
map<long, map<long, SmallVector<Value>>> weights;
|
|
||||||
size_t crossbarsUsed = 0;
|
|
||||||
|
|
||||||
TaggedWeights popGroup(size_t amount);
|
|
||||||
|
|
||||||
public:
|
|
||||||
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights);
|
|
||||||
|
|
||||||
bool isEmpty() const;
|
|
||||||
SmallVector<TaggedWeights> popGroups(size_t n);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,14 +1,17 @@
|
|||||||
add_onnx_mlir_rewriter(SpatialToGraphviz)
|
add_onnx_mlir_rewriter(SpatialToGraphviz)
|
||||||
|
|
||||||
add_onnx_mlir_library(OMSpatialToGraphviz
|
add_pim_library(OMSpatialToGraphviz
|
||||||
SpatialToGraphviz.cpp
|
SpatialToGraphviz.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRTosaDialect
|
||||||
OMCompilerOptions
|
OMCompilerOptions
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
${PIM_INCLUDE_PATH}
|
${PIM_GENERATED_INCLUDE_DIRS}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,8 +10,9 @@
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Format.h"
|
#include "llvm/Support/Format.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
||||||
@@ -199,12 +200,12 @@ private:
|
|||||||
void SpatialToGraphvizPass::runOnOperation() {
|
void SpatialToGraphvizPass::runOnOperation() {
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
|
|
||||||
// Get the first OP, must be a FuncOp
|
auto entryFunc = getPimEntryFunc(module);
|
||||||
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
|
if (failed(entryFunc)) {
|
||||||
if (!func) {
|
|
||||||
module->emitError("No FuncOp found in the begin of module");
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
func::FuncOp func = *entryFunc;
|
||||||
|
|
||||||
os << "digraph G {\n"
|
os << "digraph G {\n"
|
||||||
<< "\tnode [style=filled,color=white];\n";
|
<< "\tnode [style=filled,color=white];\n";
|
||||||
@@ -222,9 +223,6 @@ void SpatialToGraphvizPass::runOnOperation() {
|
|||||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
drawConcatOpSubgraph(concatOp, concatNum++);
|
drawConcatOpSubgraph(concatOp, concatNum++);
|
||||||
}
|
}
|
||||||
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
|
|
||||||
drawConcatOpSubgraph(imgConcatOp, concatNum++);
|
|
||||||
}
|
|
||||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
|
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
|
||||||
if (producerOp) {
|
if (producerOp) {
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
|
|
||||||
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|
||||||
add_public_tablegen_target(SpatialToPIMIncGen)
|
|
||||||
|
|
||||||
add_onnx_mlir_library(OMSpatialToPIM
|
|
||||||
SpatialToPIMPass.hpp
|
|
||||||
SpatialToPIMPass.cpp
|
|
||||||
SpatialToPIMCommon.cpp
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
SpatialToPIMIncGen
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
OMCompilerOptions
|
|
||||||
OMPIMCommon
|
|
||||||
SpatialOps
|
|
||||||
PimOps
|
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
|
||||||
${PIM_INCLUDE_PATH}
|
|
||||||
)
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
#ifndef SPATIAL_TO_PIM
|
|
||||||
#define SPATIAL_TO_PIM
|
|
||||||
|
|
||||||
#ifndef OP_BASE
|
|
||||||
include "mlir/IR/PatternBase.td"
|
|
||||||
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|
||||||
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
|
|
||||||
#endif // OP_BASE
|
|
||||||
|
|
||||||
def spatToPimVMMOp : Pat<
|
|
||||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
|
||||||
(PimVMMOp $weightIndex, $vector,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
|
||||||
>;
|
|
||||||
|
|
||||||
def spatToPimMVMOp : Pat<
|
|
||||||
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
|
||||||
(PimMVMOp $weightIndex, $vector,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
|
||||||
>;
|
|
||||||
|
|
||||||
def spatToPimVAddOp : Pat<
|
|
||||||
(SpatVAddOp:$srcOpRes $a, $b),
|
|
||||||
(PimVAddOp $a, $b,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
|
||||||
>;
|
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
|
||||||
* its static tensor input.
|
|
||||||
*
|
|
||||||
* The static offsets represent the starting position of the slice in each
|
|
||||||
* dimension, while the static tensor input gives its dimension size.
|
|
||||||
*
|
|
||||||
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
|
||||||
* calculated.
|
|
||||||
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
|
||||||
* \return The actual offset of the ExtractSliceOp.
|
|
||||||
*/
|
|
||||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
size_t rangeLength(const iterator_range<T> range) {
|
|
||||||
return std::distance(range.begin(), range.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Retrieves the earliest operation that uses the given value within the value's
|
|
||||||
* block.
|
|
||||||
*
|
|
||||||
* @param value The value for which to find the earliest user operation.
|
|
||||||
* @return The earliest user operation that uses the given value within the
|
|
||||||
* current block.
|
|
||||||
*/
|
|
||||||
Operation* getEarliestUserWithinBlock(Value value);
|
|
||||||
|
|
||||||
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
|
|
||||||
|
|
||||||
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
|
|
||||||
|
|
||||||
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
|
|
||||||
const ArrayRef<int64_t> offsets,
|
|
||||||
const ArrayRef<int64_t> sizes,
|
|
||||||
const ArrayRef<int64_t> strides) {
|
|
||||||
// Check that all strides are 1
|
|
||||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
// Check offsets from right to left:
|
|
||||||
// The first offset_n at position n different from 0:
|
|
||||||
// - limits all sizes to the left to 1
|
|
||||||
// - limits size_n to dimension_n - offset_n
|
|
||||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
|
||||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstNonZeroOffset = std::find_if(
|
|
||||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return offset != 0;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
|
||||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
|
||||||
if (size > dimension - offset)
|
|
||||||
return false;
|
|
||||||
++firstNonZeroOffset;
|
|
||||||
|
|
||||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check sizes from right to left:
|
|
||||||
// The first size_n at position n different from shape_n limits all sizes to the left to 1
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstDifferentSize != sizesAndShape.end()) {
|
|
||||||
++firstDifferentSize;
|
|
||||||
|
|
||||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, _] = sizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
|
|
||||||
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); }
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
namespace pim {
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
|
|
||||||
|
|
||||||
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
|
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
|
|
||||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
|
||||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
|
||||||
|
|
||||||
SpatialToPIMPass() = default;
|
|
||||||
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
|
|
||||||
|
|
||||||
void runOnOperation() final;
|
|
||||||
|
|
||||||
private:
|
|
||||||
SmallVector<Value> outputTensors;
|
|
||||||
size_t coreId = 0;
|
|
||||||
SmallVector<Operation*> operationsToRemove;
|
|
||||||
|
|
||||||
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
|
||||||
void addReceiveOps(Value& channelSourceOp,
|
|
||||||
spatial::SpatChannelNewOp& channel,
|
|
||||||
Type& channelTensorType,
|
|
||||||
bool& useBroadcastOp,
|
|
||||||
IRRewriter& rewriter);
|
|
||||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
|
||||||
unsigned int argIndex,
|
|
||||||
spatial::SpatChannelNewOp& channel,
|
|
||||||
Type& tensorType,
|
|
||||||
bool useBroadcastOp,
|
|
||||||
IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace pim
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<pim::SpatialToPIMPass>(); }
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
namespace spatial {
|
|
||||||
|
|
||||||
// TODO: Add here eventual patterns
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
23
src/PIM/Conversion/SpatialToPim/CMakeLists.txt
Normal file
23
src/PIM/Conversion/SpatialToPim/CMakeLists.txt
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
|
||||||
|
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||||
|
add_public_tablegen_target(SpatialToPimIncGen)
|
||||||
|
|
||||||
|
add_pim_library(OMSpatialToPim
|
||||||
|
SpatialToPimPass.cpp
|
||||||
|
Common.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
SpatialToPimIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRTosaDialect
|
||||||
|
OMCompilerOptions
|
||||||
|
OMPimCommon
|
||||||
|
SpatialOps
|
||||||
|
PimOps
|
||||||
|
|
||||||
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
|
${PIM_GENERATED_INCLUDE_DIRS}
|
||||||
|
)
|
||||||
@@ -5,9 +5,10 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "SpatialToPIMCommon.hpp"
|
#include "Common.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
@@ -53,7 +54,7 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
|||||||
return returnValue;
|
return returnValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* getEarliestUserWithinBlock(Value value) {
|
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||||
auto users = value.getUsers();
|
auto users = value.getUsers();
|
||||||
|
|
||||||
assert(!users.empty());
|
assert(!users.empty());
|
||||||
@@ -66,23 +67,24 @@ Operation* getEarliestUserWithinBlock(Value value) {
|
|||||||
return earliestUser;
|
return earliestUser;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
|
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
|
||||||
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
|
auto operandsAndUses =
|
||||||
|
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
|
||||||
return {operand, std::distance(operand.use_begin(), operand.use_end())};
|
return {operand, std::distance(operand.use_begin(), operand.use_end())};
|
||||||
});
|
});
|
||||||
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
|
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
|
||||||
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||||
Value result = operation->getResult(0);
|
mlir::Value result = operation->getResult(0);
|
||||||
auto resultType = result.getType();
|
auto resultType = result.getType();
|
||||||
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
||||||
|
|
||||||
SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
|
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
||||||
auto validOperands =
|
auto validOperands =
|
||||||
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
|
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
|
||||||
auto bestOperand = validOperands.begin();
|
auto bestOperand = validOperands.begin();
|
||||||
|
|
||||||
if (bestOperand != validOperands.end())
|
if (bestOperand != validOperands.end())
|
||||||
@@ -90,8 +92,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera
|
|||||||
|
|
||||||
auto resultShapedType = cast<ShapedType>(resultType);
|
auto resultShapedType = cast<ShapedType>(resultType);
|
||||||
rewriter.setInsertionPoint(operation);
|
rewriter.setInsertionPoint(operation);
|
||||||
return rewriter.create<tensor::EmptyOp>(
|
return tensor::EmptyOp::create(
|
||||||
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
48
src/PIM/Conversion/SpatialToPim/Common.hpp
Normal file
48
src/PIM/Conversion/SpatialToPim/Common.hpp
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||||
|
* its static tensor input.
|
||||||
|
*
|
||||||
|
* The static offsets represent the starting position of the slice in each
|
||||||
|
* dimension, while the static tensor input gives its dimension size.
|
||||||
|
*
|
||||||
|
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
||||||
|
* calculated.
|
||||||
|
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
||||||
|
* \return The actual offset of the ExtractSliceOp.
|
||||||
|
*/
|
||||||
|
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||||
|
return std::distance(range.begin(), range.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves the earliest operation that uses the given value within the value's
|
||||||
|
* block.
|
||||||
|
*
|
||||||
|
* @param value The value for which to find the earliest user operation.
|
||||||
|
* @return The earliest user operation that uses the given value within the
|
||||||
|
* current block.
|
||||||
|
*/
|
||||||
|
mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
|
||||||
|
|
||||||
|
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
||||||
|
|
||||||
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
||||||
|
|
||||||
|
inline mlir::tensor::EmptyOp
|
||||||
|
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||||
|
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
54
src/PIM/Conversion/SpatialToPim/SpatialToPim.td
Normal file
54
src/PIM/Conversion/SpatialToPim/SpatialToPim.td
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
#ifndef SPATIAL_TO_PIM
|
||||||
|
#define SPATIAL_TO_PIM
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "mlir/IR/PatternBase.td"
|
||||||
|
include "mlir/Dialect/Tensor/IR/TensorOps.td"
|
||||||
|
include "src/Dialect/ONNX/ONNX.td"
|
||||||
|
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||||
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def onnxToPimTranspose : Pat<
|
||||||
|
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||||
|
(PimTransposeOp $data, $perms,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVMM : Pat<
|
||||||
|
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
||||||
|
(PimVMMOp $weightIndex, $vector,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimMVM : Pat<
|
||||||
|
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
||||||
|
(PimMVMOp $weightIndex, $vector,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVVAdd : Pat<
|
||||||
|
(SpatVAddOp:$srcOpRes $a, $b),
|
||||||
|
(PimVVAddOp $a, $b,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVVMul : Pat<
|
||||||
|
(SpatVMulOp:$srcOpRes $a, $b),
|
||||||
|
(PimVVMulOp $a, $b,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVVMax : Pat<
|
||||||
|
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||||
|
(PimVVMaxOp $a, $b,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatToPimVRelu : Pat<
|
||||||
|
(SpatReluOp:$srcOpRes $input),
|
||||||
|
(PimVReluOp $input,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
|
#endif // SPATIAL_TO_PIM
|
||||||
@@ -1,8 +1,11 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/IR/BuiltinDialect.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
@@ -16,21 +19,122 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "SpatialToPIMPass.hpp"
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir;
|
||||||
|
using namespace pim;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace pim {
|
namespace {
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnOperation() {
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||||
|
|
||||||
|
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||||
|
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||||
|
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||||
|
|
||||||
|
SpatialToPimPass() = default;
|
||||||
|
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||||
|
|
||||||
|
void runOnOperation() final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
SmallVector<Value> outputTensors;
|
||||||
|
size_t coreId = 0;
|
||||||
|
SmallVector<Operation*> operationsToRemove;
|
||||||
|
|
||||||
|
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
|
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
|
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||||
|
void
|
||||||
|
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
|
||||||
|
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
|
unsigned int argIndex,
|
||||||
|
Value channelSourceOp,
|
||||||
|
Value consumerValue,
|
||||||
|
spatial::SpatChannelNewOp& channel,
|
||||||
|
bool useBroadcastOp,
|
||||||
|
IRRewriter& rewriter);
|
||||||
|
void markOpToRemove(Operation* op);
|
||||||
|
|
||||||
|
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
|
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
|
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
static bool isChannelUseChainOp(Operation* op) {
|
||||||
|
return isa<tensor::ExtractSliceOp,
|
||||||
|
tensor::CollapseShapeOp,
|
||||||
|
tensor::ExpandShapeOp,
|
||||||
|
tensor::CastOp,
|
||||||
|
tosa::ReshapeOp,
|
||||||
|
pim::PimTransposeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||||
|
for (Value operand : op->getOperands()) {
|
||||||
|
if (mapping.lookupOrNull(operand))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t countComputeLeafUsers(Value value) {
|
||||||
|
size_t leafUserCount = 0;
|
||||||
|
|
||||||
|
auto walkUses = [&](Value currentValue, auto& self) -> void {
|
||||||
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
|
Operation* owner = use.getOwner();
|
||||||
|
if (isa<spatial::SpatWeightedCompute>(owner)) {
|
||||||
|
leafUserCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isChannelUseChainOp(owner))
|
||||||
|
llvm_unreachable("Channel use chain contains unsupported op");
|
||||||
|
|
||||||
|
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||||
|
self(owner->getResult(0), self);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
walkUses(value, walkUses);
|
||||||
|
return leafUserCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 1;
|
coreId = 1;
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>();
|
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
populateWithGenerated(patterns);
|
populateWithGenerated(patterns);
|
||||||
@@ -40,15 +144,21 @@ void SpatialToPIMPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
if (!funcOp)
|
if (failed(entryFunc)) {
|
||||||
llvm_unreachable("No FuncOp found in the begin of module");
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
|
|
||||||
addResultBuffer(returnOp, rewriter);
|
addResultBuffer(returnOp, rewriter);
|
||||||
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||||
operationsToRemove.push_back(receiveOp);
|
operationsToRemove.push_back(receiveOp);
|
||||||
@@ -74,10 +184,10 @@ void SpatialToPIMPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "pim");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
auto& block = computeOp.getRegion().front();
|
auto& block = computeOp.getRegion().front();
|
||||||
@@ -87,35 +197,67 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
||||||
|
|
||||||
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
||||||
// If this result has no uses, then just skip it
|
|
||||||
if (result.use_empty())
|
if (result.use_empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||||
|
|
||||||
/*
|
|
||||||
* Here we assume that ReturnOp are only reachable by the following patterns:
|
|
||||||
*
|
|
||||||
* 1)
|
|
||||||
* %0 = spat.compute([...])
|
|
||||||
* [%0 has one user, which is a ConcatOp]
|
|
||||||
* %1 = tensor.concat(%0)
|
|
||||||
* [%1 has one user, which is a ReturnOp]
|
|
||||||
* return %1
|
|
||||||
*
|
|
||||||
* 2)
|
|
||||||
* %0 = spat.compute([...])
|
|
||||||
* [%0 has one user, which is a ReturnOp]
|
|
||||||
* return %0
|
|
||||||
*
|
|
||||||
* If the IR is like 2), then we can store the tensor to the output global memory location
|
|
||||||
*/
|
|
||||||
auto resultUses = result.getUses();
|
auto resultUses = result.getUses();
|
||||||
auto numResultUses = rangeLength(resultUses);
|
auto numResultUses = rangeLength(resultUses);
|
||||||
if (numResultUses == 1) {
|
if (numResultUses == 1) {
|
||||||
OpOperand& resultUse = *resultUses.begin();
|
OpOperand& resultUse = *resultUses.begin();
|
||||||
Operation* resultUser = resultUse.getOwner();
|
Operation* resultUser = resultUse.getOwner();
|
||||||
|
|
||||||
|
if (isChannelUseChainOp(resultUser)) {
|
||||||
|
SmallVector<Operation*> returnChain;
|
||||||
|
Value chainedValue = result;
|
||||||
|
Operation* chainUser = resultUser;
|
||||||
|
|
||||||
|
while (isChannelUseChainOp(chainUser)) {
|
||||||
|
returnChain.push_back(chainUser);
|
||||||
|
auto chainUses = chainUser->getResult(0).getUses();
|
||||||
|
if (rangeLength(chainUses) != 1)
|
||||||
|
break;
|
||||||
|
chainedValue = chainUser->getResult(0);
|
||||||
|
chainUser = chainUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<func::ReturnOp>(chainUser)) {
|
||||||
|
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(yieldOp);
|
||||||
|
IRMapping mapping;
|
||||||
|
mapping.map(result, yieldValue);
|
||||||
|
|
||||||
|
Value storedValue = yieldValue;
|
||||||
|
for (Operation* op : returnChain) {
|
||||||
|
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
storedValue = clonedOp->getResult(0);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
markOpToRemove(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||||
|
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
|
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||||
|
if (auto storedOp = storedValue.getDefiningOp())
|
||||||
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
|
PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputTensor.getType(),
|
||||||
|
outputTensor,
|
||||||
|
storedValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (isa<func::ReturnOp>(resultUser)) {
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
@@ -125,7 +267,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// Store to global memory
|
// Store to global memory
|
||||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
rewriter.create<PimMemCopyDevToHostOp>(loc,
|
PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
yieldValue,
|
yieldValue,
|
||||||
@@ -135,7 +278,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<tensor::ConcatOp>(resultUser) || isa<spatial::SpatImgConcatOp>(resultUser)) {
|
if (isa<tensor::ConcatOp>(resultUser)) {
|
||||||
auto concatOp = resultUser;
|
auto concatOp = resultUser;
|
||||||
auto concatValue = concatOp->getResult(0);
|
auto concatValue = concatOp->getResult(0);
|
||||||
auto concatUses = concatValue.getUses();
|
auto concatUses = concatValue.getUses();
|
||||||
@@ -156,7 +299,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// Store to global memory
|
// Store to global memory
|
||||||
Value outputTensor = outputTensors[concatIndexInReturn];
|
Value outputTensor = outputTensors[concatIndexInReturn];
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
rewriter.create<PimMemCopyDevToHostOp>(
|
PimMemCopyDevToHostOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -175,23 +318,20 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// 1. Create a new ChannelOp
|
// 1. Create a new ChannelOp
|
||||||
rewriter.setInsertionPoint(computeOp);
|
rewriter.setInsertionPoint(computeOp);
|
||||||
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
||||||
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
|
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||||
|
|
||||||
// 2. Receive value through the channel
|
// 2. Receive value through the channel. Broadcast is needed whenever the
|
||||||
// If this result is used by more than one user, then use a "Broadcast"
|
// value eventually reaches more than one compute consumer, even through a
|
||||||
// channel operation. However, there is a special case: we have a single
|
// chain of view-like ops.
|
||||||
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
|
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
|
||||||
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
|
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
|
||||||
// will detect this case and update `useBroadcastOp` accordingly.
|
|
||||||
bool useBroadcastOp = (numResultUses > 1);
|
|
||||||
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
|
|
||||||
|
|
||||||
// 3. Send the value through the channel
|
// 3. Send the value through the channel
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
if (useBroadcastOp)
|
if (useBroadcastOp)
|
||||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue);
|
spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
|
||||||
else
|
else
|
||||||
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue);
|
spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use `HaltOp` instead of `YieldOp`
|
// Use `HaltOp` instead of `YieldOp`
|
||||||
@@ -200,17 +340,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
|
|
||||||
// Replace `spat.compute` with `pim.core`
|
// Replace `spat.compute` with `pim.core`
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
|
auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
|
||||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
block.eraseArguments(0, block.getNumArguments());
|
block.eraseArguments(0, block.getNumArguments());
|
||||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||||
Block* tempComputeBlock = new Block();
|
Block* tempComputeBlock = new Block();
|
||||||
computeOp.getBody().push_back(tempComputeBlock);
|
computeOp.getBody().push_back(tempComputeBlock);
|
||||||
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
||||||
rewriter.create<PimHaltOp>(computeOp.getLoc());
|
PimHaltOp::create(rewriter, computeOp.getLoc());
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
||||||
auto* definingOp = value.getDefiningOp();
|
auto* definingOp = value.getDefiningOp();
|
||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
@@ -228,8 +368,8 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
};
|
};
|
||||||
|
|
||||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||||
auto outTensorOperand = vmmOp.getOutBuf();
|
auto outTensorOperand = vmmOp.getOutputBuffer();
|
||||||
auto resultTensor = vmmOp.getOutRes();
|
auto resultTensor = vmmOp.getOutput();
|
||||||
auto outShape = getTensorShape(outTensorOperand);
|
auto outShape = getTensorShape(outTensorOperand);
|
||||||
assert(isHVectorShape(outShape));
|
assert(isHVectorShape(outShape));
|
||||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
||||||
@@ -247,24 +387,31 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
||||||
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
||||||
rewriter.setInsertionPointAfter(vmmOp);
|
rewriter.setInsertionPointAfter(vmmOp);
|
||||||
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
||||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||||
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
outputTensors.reserve(returnOp->getNumOperands());
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
||||||
for (auto returnValue : returnOp->getOperands()) {
|
for (auto returnValue : returnOp->getOperands()) {
|
||||||
|
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
||||||
|
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||||
|
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||||
|
outputTensors.push_back(returnValue);
|
||||||
|
}
|
||||||
|
else {
|
||||||
auto newOutputTensor =
|
auto newOutputTensor =
|
||||||
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
|
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
|
||||||
outputTensors.push_back(newOutputTensor);
|
outputTensors.push_back(newOutputTensor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||||
@@ -273,9 +420,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
||||||
|
|
||||||
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType);
|
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||||
|
|
||||||
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>(
|
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
|
||||||
|
rewriter,
|
||||||
loc,
|
loc,
|
||||||
tensorType,
|
tensorType,
|
||||||
deviceTensor,
|
deviceTensor,
|
||||||
@@ -295,16 +443,19 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
||||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
||||||
|
|
||||||
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
|
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
||||||
|
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
||||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
||||||
|
|
||||||
Block& block = funcOp.getBody().front();
|
Block& block = funcOp.getBody().front();
|
||||||
rewriter.setInsertionPoint(&block.front());
|
rewriter.setInsertionPoint(&block.front());
|
||||||
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
auto toTensorOp =
|
||||||
|
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
||||||
inputTensors.push_back(toTensorOp);
|
inputTensors.push_back(toTensorOp);
|
||||||
|
|
||||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||||
funcOp.eraseArgument(i);
|
if (failed(funcOp.eraseArgument(i)))
|
||||||
|
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||||
@@ -318,6 +469,9 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||||
|
|
||||||
|
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
||||||
|
continue;
|
||||||
|
|
||||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
||||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
||||||
@@ -351,12 +505,15 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
for (auto sliceOp : sliceOpsToRemove)
|
for (auto sliceOp : sliceOpsToRemove)
|
||||||
if (sliceOp->getUses().empty())
|
if (sliceOp->getUses().empty())
|
||||||
rewriter.eraseOp(sliceOp);
|
rewriter.eraseOp(sliceOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
|
Value channelSourceOp,
|
||||||
|
Value consumerValue,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& tensorType,
|
|
||||||
bool useBroadcastOp,
|
bool useBroadcastOp,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
auto& computeBlock = computeOp.getRegion().front();
|
auto& computeBlock = computeOp.getRegion().front();
|
||||||
@@ -369,70 +526,74 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||||
Value receivedValue;
|
Value receivedValue;
|
||||||
if (useBroadcastOp)
|
if (useBroadcastOp)
|
||||||
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel);
|
receivedValue =
|
||||||
|
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
else
|
else
|
||||||
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel);
|
receivedValue =
|
||||||
|
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
|
|
||||||
blockArg.replaceAllUsesWith(receivedValue);
|
Value replacementValue = receivedValue;
|
||||||
|
if (consumerValue != channelSourceOp) {
|
||||||
|
SmallVector<Operation*> clonedChain;
|
||||||
|
Value currentValue = consumerValue;
|
||||||
|
while (currentValue != channelSourceOp) {
|
||||||
|
Operation* definingOp = currentValue.getDefiningOp();
|
||||||
|
if (!definingOp || !isChannelUseChainOp(definingOp))
|
||||||
|
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
|
||||||
|
|
||||||
|
clonedChain.push_back(definingOp);
|
||||||
|
currentValue = definingOp->getOperand(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
|
IRMapping mapping;
|
||||||
|
mapping.map(channelSourceOp, receivedValue);
|
||||||
|
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||||
|
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
markOpToRemove(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(replacementValue.getType() == blockArg.getType()
|
||||||
|
&& "Replayed channel use chain must match block argument type");
|
||||||
|
blockArg.replaceAllUsesWith(replacementValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& channelTensorType,
|
bool useBroadcastOp,
|
||||||
bool& useBroadcastOp,
|
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
auto sourceOpUses = channelSourceOp.getUses();
|
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
||||||
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
|
Operation* owner = use.getOwner();
|
||||||
if (useBroadcastOp == false) {
|
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
|
||||||
// if useBroadcastOp is false, then sourceOp must have only one user
|
|
||||||
assert(rangeLength(sourceOpUses) == 1);
|
|
||||||
|
|
||||||
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
|
|
||||||
auto reshapeOpUses = reshapeOp.getOutput().getUses();
|
|
||||||
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
|
|
||||||
if (reshapeOpUsesCount > 1)
|
|
||||||
useBroadcastOp = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& resultUse : sourceOpUses) {
|
|
||||||
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
|
|
||||||
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
|
|
||||||
|
|
||||||
if (computeUser) {
|
|
||||||
replaceBlockArgumentWithRecvOp(
|
replaceBlockArgumentWithRecvOp(
|
||||||
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!computeUser) {
|
if (!isChannelUseChainOp(owner))
|
||||||
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
|
||||||
if (!reshapeOp) {
|
|
||||||
resultUse.getOwner()->dump();
|
markOpToRemove(owner);
|
||||||
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||||
|
self(owner->getResult(0), self);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The tensorType now becomes the one of the reshapeOp
|
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||||
channelTensorType = reshapeOp.getResult().getType();
|
if (!llvm::is_contained(operationsToRemove, op))
|
||||||
|
operationsToRemove.push_back(op);
|
||||||
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
|
|
||||||
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
|
|
||||||
|
|
||||||
if (!computeUser)
|
|
||||||
llvm_unreachable("ReshapeOp users must be ComputeOps");
|
|
||||||
|
|
||||||
replaceBlockArgumentWithRecvOp(
|
|
||||||
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the reshapeOp, so that the sourceOp has no users
|
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
operationsToRemove.push_back(reshapeOp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
|
||||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||||
Operation* returnOperand = it.value().getDefiningOp();
|
Operation* returnOperand = it.value().getDefiningOp();
|
||||||
|
|
||||||
@@ -441,9 +602,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
rewriter.modifyOpInPlace(returnOp,
|
rewriter.modifyOpInPlace(returnOp,
|
||||||
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
||||||
|
|
||||||
// If the operand is a concatenation operation and the returnOp was the only
|
if (isa<tensor::ConcatOp>(returnOperand)) {
|
||||||
// user of the returnOperand, we can safely remove it
|
|
||||||
if (isAConcatOp(returnOperand)) {
|
|
||||||
auto returnOperandUses = it.value().getUses();
|
auto returnOperandUses = it.value().getUses();
|
||||||
if (rangeLength(returnOperandUses) == 0)
|
if (rangeLength(returnOperandUses) == 0)
|
||||||
rewriter.eraseOp(returnOperand);
|
rewriter.eraseOp(returnOperand);
|
||||||
@@ -451,7 +610,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
||||||
|
|
||||||
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
|
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
|
||||||
|
|
||||||
@@ -461,25 +620,20 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
|||||||
|
|
||||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||||
|
|
||||||
auto tensorType = receiveOp.getType();
|
|
||||||
Value receiveRes = receiveOp.getResult();
|
Value receiveRes = receiveOp.getResult();
|
||||||
|
|
||||||
// Check if the receiveOp value has more than one user
|
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
|
||||||
auto receiveUses = receiveRes.getUses();
|
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
|
||||||
auto receiveUsesCount = rangeLength(receiveUses);
|
|
||||||
assert(receiveUsesCount > 0);
|
|
||||||
bool useBroadcastOp = receiveUsesCount > 1;
|
|
||||||
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
|
|
||||||
|
|
||||||
if (useBroadcastOp) {
|
if (useBroadcastOp) {
|
||||||
// When receiving, we actually noticed that the value has more than one
|
// When receiving, we actually noticed that the value has more than one
|
||||||
// user. This means that we need to get the replace the original SendOp with
|
// user. This means that we need to get the replace the original SendOp with
|
||||||
// a BroadcastSendOp
|
// a BroadcastSendOp
|
||||||
rewriter.setInsertionPoint(sendOp);
|
rewriter.setInsertionPoint(sendOp);
|
||||||
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
|
rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getInput());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pim
|
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -1,2 +1,2 @@
|
|||||||
add_subdirectory(PIM)
|
add_subdirectory(Pim)
|
||||||
add_subdirectory(Spatial)
|
add_subdirectory(Spatial)
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
#ifndef PIM_DIALECT_H
|
|
||||||
#define PIM_DIALECT_H
|
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
|
||||||
include "mlir/IR/AttrTypeBase.td"
|
|
||||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
|
||||||
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
|
||||||
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
|
|
||||||
|
|
||||||
def PimDialect : Dialect {
|
|
||||||
let name = "pim";
|
|
||||||
let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
|
|
||||||
let cppNamespace = "::onnx_mlir::pim";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Base class for Pim dialect operations. This operation inherits from the
|
|
||||||
// base `Op` class in OpBase.td, and provides:
|
|
||||||
// * The parent dialect of the operation.
|
|
||||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
|
||||||
// * A list of traits for the operation.
|
|
||||||
class PimOp<string mnemonic, list<Trait> traits = []> :
|
|
||||||
Op<PimDialect, mnemonic, traits>;
|
|
||||||
|
|
||||||
def PimTensor :
|
|
||||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Communication operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimSendOp: PimOp<"send", []> {
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $src,
|
|
||||||
I32Attr: $size,
|
|
||||||
I32Attr: $targetCoreId
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $src `)` attr-dict `:` type($src) `->` `(` `)`
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $dst,
|
|
||||||
I32Attr: $size,
|
|
||||||
I32Attr: $srcCoreId
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $out
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $dst `)` attr-dict `:` type($dst) `->` type($out)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Core operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimCoreOp: PimOp<"core", [SingleBlock]> {
|
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<PimTensor>:$weights,
|
|
||||||
I32Attr: $coreId
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Memory Operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimConstantOp: PimOp<"constant", []> {
|
|
||||||
let description = [{
|
|
||||||
Allocate a constant value in global memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
AnyAttr: $value,
|
|
||||||
BoolAttr: $shouldAllocate
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $out
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyHostToDevOp: PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Copy a memory region from host memory into device memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $deviceDst,
|
|
||||||
PimTensor: $hostSrc,
|
|
||||||
I32Attr: $deviceDstOffset,
|
|
||||||
I32Attr: $hostSrcOffset,
|
|
||||||
I32Attr: $size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $deviceDstOut
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDeviceDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $deviceDst `,` $hostSrc `)` attr-dict `:` `(` type($deviceDst) `,` type($hostSrc) `)` `->` type($deviceDstOut)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Copy a memory region from device memory into host memory
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $hostDst,
|
|
||||||
PimTensor: $deviceSrc,
|
|
||||||
I32Attr: $hostDstOffset,
|
|
||||||
I32Attr: $deviceSrcOffset,
|
|
||||||
I32Attr: $size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $hostDstOut
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getHostDstMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $hostDst `,` $deviceSrc `)` attr-dict `:` `(` type($hostDst) `,` type($deviceSrc) `)` `->` type($hostDstOut)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Core.Compute operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Vector-matrix multiplication: c = a * b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr: $weightIndex,
|
|
||||||
PimTensor: $vectorInput,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Matrix-vector multiplication: c = a * b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr: $weightIndex,
|
|
||||||
PimTensor: $vectorInput,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise addition: c = a + b
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutBufMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise max: c = max(a, b)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $b,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Apply filters to a tensor
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I64ArrayAttr: $weightIndices,
|
|
||||||
I64ArrayAttr: $xKernelPositions,
|
|
||||||
I64ArrayAttr: $yKernelPositions,
|
|
||||||
PimTensor: $input,
|
|
||||||
PimTensor: $outBuf,
|
|
||||||
PimTensor: $accumBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` `input` `=` $input `,` `outBuf` `=` $outBuf `,` `accumBuf` `=` $accumBuf `)` attr-dict `:`
|
|
||||||
type($input) `,` type($outBuf) `,` type($accumBuf) `->` type($outRes)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Sum all elements into a single one
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $dividend,
|
|
||||||
PimTensor: $divisor,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise ReLU: c = max(a, 0)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
|
||||||
let description = [{
|
|
||||||
Element-wise exp: c = exp(a)
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor: $a,
|
|
||||||
PimTensor: $outBuf
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor: $outRes
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimHaltOp: PimOp<"halt", [Terminator]> {
|
|
||||||
let description = [{
|
|
||||||
Halts the execution of the core
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
attr-dict
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // PIM_DIALECT_H
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
#include "mlir/IR/IntegerSet.h"
|
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
#include "mlir/IR/OpImplementation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace pim {
|
|
||||||
|
|
||||||
void PimDialect::initialize() {
|
|
||||||
addOperations<
|
|
||||||
#define GET_OP_LIST
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
|
||||||
|
|
||||||
>();
|
|
||||||
}
|
|
||||||
|
|
||||||
#define POPULATE_DEPENDENCIES(OP_NAME) \
|
|
||||||
void OP_NAME::populateDependencies(bufferization::RegisterDependenciesFn registerDependenciesFn) { \
|
|
||||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
|
||||||
}
|
|
||||||
|
|
||||||
POPULATE_DEPENDENCIES(PimVMaxOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimSumOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVSDivOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
|
||||||
POPULATE_DEPENDENCIES(PimVExpOp)
|
|
||||||
|
|
||||||
} // namespace pim
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// TableGen'd op method definitions
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
|
|
||||||
@@ -1,187 +0,0 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace bufferization;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace pim {
|
|
||||||
|
|
||||||
struct MemCopyHostToDevOpInterface
|
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
|
||||||
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
|
|
||||||
auto hostSrc = memCopyHostToDevOp.getHostSrc();
|
|
||||||
|
|
||||||
auto deviceDstOpt = getBuffer(rewriter, deviceDst, options, state);
|
|
||||||
if (failed(deviceDstOpt))
|
|
||||||
return failure();
|
|
||||||
auto deviceDstMemRef = *deviceDstOpt;
|
|
||||||
|
|
||||||
auto hostSrcOpt = getBuffer(rewriter, hostSrc, options, state);
|
|
||||||
if (failed(hostSrcOpt))
|
|
||||||
return failure();
|
|
||||||
auto hostSrcMemRef = *hostSrcOpt;
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
|
||||||
memCopyHostToDevOp,
|
|
||||||
deviceDstMemRef.getType(),
|
|
||||||
deviceDstMemRef,
|
|
||||||
hostSrcMemRef,
|
|
||||||
memCopyHostToDevOp.getDeviceDstOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getHostSrcOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getSizeAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MemCopyDevToHostOpInterface
|
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
|
||||||
|
|
||||||
auto globalDst = memCopyDevToHostOp.getHostDst();
|
|
||||||
auto globalDstOpt = getBuffer(rewriter, globalDst, options, state);
|
|
||||||
if (failed(globalDstOpt))
|
|
||||||
return failure();
|
|
||||||
auto globalDstMemRef = *globalDstOpt;
|
|
||||||
|
|
||||||
auto localSrc = memCopyDevToHostOp.getDeviceSrc();
|
|
||||||
auto localSrcOpt = getBuffer(rewriter, localSrc, options, state);
|
|
||||||
if (failed(localSrcOpt))
|
|
||||||
return failure();
|
|
||||||
auto localSrcMemRef = *localSrcOpt;
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
|
||||||
memCopyDevToHostOp,
|
|
||||||
globalDstMemRef.getType(),
|
|
||||||
globalDstMemRef,
|
|
||||||
localSrcMemRef,
|
|
||||||
memCopyDevToHostOp.getHostDstOffsetAttr(),
|
|
||||||
memCopyDevToHostOp.getDeviceSrcOffsetAttr(),
|
|
||||||
memCopyDevToHostOp.getSizeAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
|
|
||||||
auto vmmOp = cast<PimVMMOp>(op);
|
|
||||||
Value readVal = uRead->get();
|
|
||||||
Value writeVal = uWrite->get();
|
|
||||||
if (writeVal != vmmOp.getOutBuf())
|
|
||||||
return false;
|
|
||||||
if (readVal == vmmOp.getVectorInput())
|
|
||||||
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto vmmOp = cast<PimVMMOp>(op);
|
|
||||||
|
|
||||||
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
|
|
||||||
if (failed(vectorInputOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, vmmOp.getOutBuf(), options, state);
|
|
||||||
if (failed(outBufOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
|
||||||
rewriter, op, outBufOpt->getType(), vmmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBufferizeInterface, PimMVMOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto mvmOp = cast<PimMVMOp>(op);
|
|
||||||
|
|
||||||
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
|
|
||||||
if (failed(vectorInputOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, mvmOp.getOutBuf(), options, state);
|
|
||||||
if (failed(outBufOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
|
||||||
rewriter, op, outBufOpt->getType(), mvmOp.getWeightIndexAttr(), *vectorInputOpt, *outBufOpt);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
|
||||||
const AnalysisState& state,
|
|
||||||
ArrayRef<OpOperand*> opOperands) const {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto vaddOp = cast<PimVAddOp>(op);
|
|
||||||
|
|
||||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
|
||||||
if (failed(aOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
|
|
||||||
if (failed(bOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
|
|
||||||
if (failed(outBufOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
|
||||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
|
||||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
|
||||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace pim
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/DialectRegistry.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace pim {
|
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
|
||||||
|
|
||||||
} // namespace pim
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
add_onnx_mlir_dialect(Pim pim)
|
add_onnx_mlir_dialect(Pim pim)
|
||||||
add_onnx_mlir_dialect_doc(pim Pim.td)
|
add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||||
|
|
||||||
|
add_subdirectory(Transforms/Bufferization)
|
||||||
|
|
||||||
add_onnx_mlir_library(PimOps
|
add_pim_library(PimOps
|
||||||
|
PimOps.hpp
|
||||||
PimOps.cpp
|
PimOps.cpp
|
||||||
Transforms/PimBufferizableOpInterface.cpp
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
OMPimIncGen
|
OMPimIncGen
|
||||||
458
src/PIM/Dialect/Pim/Pim.td
Normal file
458
src/PIM/Dialect/Pim/Pim.td
Normal file
@@ -0,0 +1,458 @@
|
|||||||
|
#ifndef PIM_DIALECT_H
|
||||||
|
#define PIM_DIALECT_H
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir/IR/AttrTypeBase.td"
|
||||||
|
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
||||||
|
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
|
||||||
|
|
||||||
|
def PimDialect : Dialect {
|
||||||
|
let name = "pim";
|
||||||
|
let summary = "A low-level dialect for the PIM coprocessors on ReRAM crossbars";
|
||||||
|
let cppNamespace = "::onnx_mlir::pim";
|
||||||
|
}
|
||||||
|
|
||||||
|
class PimOp<string mnemonic, list<Trait> traits = []> :
|
||||||
|
Op<PimDialect, mnemonic, traits>;
|
||||||
|
|
||||||
|
def PimTensor :
|
||||||
|
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Execution
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
||||||
|
let summary = "Execute a block on a PIM core";
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<PimTensor>:$weights,
|
||||||
|
I32Attr:$coreId
|
||||||
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
||||||
|
let summary = "Halt execution of the core";
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Communication
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def PimSendOp : PimOp<"send", []> {
|
||||||
|
let summary = "Send a tensor to another core";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
I32Attr:$size,
|
||||||
|
I32Attr:$targetCoreId
|
||||||
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Receive a tensor from another core";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$outputBuffer,
|
||||||
|
I32Attr:$size,
|
||||||
|
I32Attr:$sourceCoreId
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Copy a memory region from host memory into device memory";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$deviceTarget,
|
||||||
|
PimTensor:$hostSource,
|
||||||
|
I32Attr:$deviceTargetOffset,
|
||||||
|
I32Attr:$hostSourceOffset,
|
||||||
|
I32Attr:$size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getDeviceTargetMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Copy a memory region from device memory into host memory";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$hostTarget,
|
||||||
|
PimTensor:$deviceSource,
|
||||||
|
I32Attr:$hostTargetOffset,
|
||||||
|
I32Attr:$deviceSourceOffset,
|
||||||
|
I32Attr:$size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getHostTargetMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Copy a memory region within the same memory space";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$target,
|
||||||
|
PimTensor:$source,
|
||||||
|
I32Attr:$targetOffset,
|
||||||
|
I32Attr:$sourceOffset,
|
||||||
|
I32Attr:$size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getTargetMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Math
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def PimTransposeOp : PimOp<"transpose", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Transpose a matrix";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
I64ArrayAttr:$permutation,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Vector-matrix multiplication: c = a * b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I32Attr:$weightIndex,
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Matrix-vector multiplication: c = a * b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I32Attr:$weightIndex,
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise addition: c = a + b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVSubOp : PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise subtraction: c = a - b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVMulOp : PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise multiplication: c = a * b";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVMaxOp : PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise max: c = max(a, b)";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVVDMulOp : PimOp<"vvdmul", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Dot product: c = dot(a, b)";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$lhs,
|
||||||
|
PimTensor:$rhs,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $lhs `,` $rhs `,` $outputBuffer `)` attr-dict `:` `(` type($lhs) `,` type($rhs) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVAvgOp : PimOp<"vavg", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Average all elements into a single value";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVReluOp : PimOp<"vrelu", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise ReLU: c = max(a, 0)";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVTanhOp : PimOp<"vtanh", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise tanh activation";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
|
||||||
|
let summary = "Element-wise sigmoid activation";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
PimTensor:$input,
|
||||||
|
PimTensor:$outputBuffer
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
PimTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||||
|
return getOutputBufferMutable();
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // PIM_DIALECT_H
|
||||||
37
src/PIM/Dialect/Pim/PimOps.cpp
Normal file
37
src/PIM/Dialect/Pim/PimOps.cpp
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/IntegerSet.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace pim {
|
||||||
|
|
||||||
|
void PimDialect::initialize() {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||||
|
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace pim
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TableGen'd op method definitions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
/// Include the auto-generated header files containing the declarations
|
/// Include the auto-generated header files containing the declarations
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"
|
||||||
23
src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt
Normal file
23
src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
||||||
|
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||||
|
add_public_tablegen_target(PimBufferizationIncGen)
|
||||||
|
|
||||||
|
add_pim_library(OMPimBufferization
|
||||||
|
PimBufferizationPass.cpp
|
||||||
|
OpBufferizationInterfaces.hpp
|
||||||
|
OpBufferizationInterfaces.cpp
|
||||||
|
Common.hpp
|
||||||
|
Common.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
PimBufferizationIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
OMPimCommon
|
||||||
|
PimOps
|
||||||
|
|
||||||
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
|
${PIM_GENERATED_INCLUDE_DIRS}
|
||||||
|
)
|
||||||
9
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp
Normal file
9
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||||
|
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||||
|
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||||
|
return builder.getI32IntegerAttr(sizeInBytes);
|
||||||
|
}
|
||||||
11
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp
Normal file
11
src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace pim {
|
||||||
|
|
||||||
|
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
|
||||||
|
|
||||||
|
} // namespace pim
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,280 @@
|
|||||||
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "OpBufferizationInterfaces.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace bufferization;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace pim {
|
||||||
|
|
||||||
|
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
|
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||||
|
return memrefValue;
|
||||||
|
|
||||||
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
|
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||||
|
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||||
|
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
|
return PimMemCopyOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
contiguousType,
|
||||||
|
contiguousBuffer,
|
||||||
|
memrefValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MemCopyHostToDevOpInterface
|
||||||
|
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
||||||
|
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
|
||||||
|
auto hostSource = memCopyHostToDevOp.getHostSource();
|
||||||
|
|
||||||
|
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state);
|
||||||
|
if (failed(deviceTargetOpt))
|
||||||
|
return failure();
|
||||||
|
auto deviceTargetMemRef = *deviceTargetOpt;
|
||||||
|
|
||||||
|
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state);
|
||||||
|
if (failed(hostSourceOpt))
|
||||||
|
return failure();
|
||||||
|
auto hostSourceMemRef = *hostSourceOpt;
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||||
|
memCopyHostToDevOp,
|
||||||
|
deviceTargetMemRef.getType(),
|
||||||
|
deviceTargetMemRef,
|
||||||
|
hostSourceMemRef,
|
||||||
|
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||||
|
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||||
|
memCopyHostToDevOp.getSizeAttr());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MemCopyDevToHostOpInterface
|
||||||
|
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||||
|
|
||||||
|
auto hostTarget = memCopyDevToHostOp.getHostTarget();
|
||||||
|
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state);
|
||||||
|
if (failed(hostTargetOpt))
|
||||||
|
return failure();
|
||||||
|
auto hostTargetMemRef = *hostTargetOpt;
|
||||||
|
|
||||||
|
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
|
||||||
|
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state);
|
||||||
|
if (failed(deviceSourceOpt))
|
||||||
|
return failure();
|
||||||
|
auto deviceSourceMemRef = *deviceSourceOpt;
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||||
|
memCopyDevToHostOp,
|
||||||
|
hostTargetMemRef.getType(),
|
||||||
|
hostTargetMemRef,
|
||||||
|
deviceSourceMemRef,
|
||||||
|
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
||||||
|
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
||||||
|
memCopyDevToHostOp.getSizeAttr());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto transposeOp = cast<PimTransposeOp>(op);
|
||||||
|
|
||||||
|
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state);
|
||||||
|
if (failed(inputOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||||
|
rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, PimVMMOp> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isNotConflicting(Operation* op, OpOperand* uRead, OpOperand* uWrite, const AnalysisState& state) const {
|
||||||
|
auto vmmOp = cast<PimVMMOp>(op);
|
||||||
|
Value readVal = uRead->get();
|
||||||
|
Value writeVal = uWrite->get();
|
||||||
|
if (writeVal != vmmOp.getOutputBuffer())
|
||||||
|
return false;
|
||||||
|
if (readVal == vmmOp.getInput())
|
||||||
|
if (state.areEquivalentBufferizedValues(readVal, writeVal))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto vmmOp = cast<PimVMMOp>(op);
|
||||||
|
|
||||||
|
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state);
|
||||||
|
if (failed(inputOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||||
|
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto mvmOp = cast<PimMVMOp>(op);
|
||||||
|
|
||||||
|
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state);
|
||||||
|
if (failed(inputOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||||
|
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||||
|
const AnalysisState& state,
|
||||||
|
ArrayRef<OpOperand*> opOperands) const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto binaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
|
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state);
|
||||||
|
if (failed(lhsOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state);
|
||||||
|
if (failed(rhsOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||||
|
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<OpTy>(
|
||||||
|
rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bufferizesToElementwiseAccess(Operation* tablegen_opaque_val,
|
||||||
|
const AnalysisState& state,
|
||||||
|
ArrayRef<OpOperand*> opOperands) const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto unaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
|
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
|
||||||
|
if (failed(inputOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
|
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||||
|
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
||||||
|
|
||||||
|
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||||
|
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||||
|
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
||||||
|
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
||||||
|
|
||||||
|
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
||||||
|
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||||
|
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
|
||||||
|
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace pim
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/DialectRegistry.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace pim {
|
||||||
|
|
||||||
|
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
|
||||||
|
|
||||||
|
} // namespace pim
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
#ifndef PIM_BUFFERIZATION
|
||||||
|
#define PIM_BUFFERIZATION
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "mlir/IR/PatternBase.td"
|
||||||
|
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||||
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def memrefCopyToPimMemCopyOp : Pat<
|
||||||
|
(CopyOp $src, $dst),
|
||||||
|
(PimMemCopyOp $dst, $src,
|
||||||
|
ConstantAttr<I32Attr, "0">,
|
||||||
|
ConstantAttr<I32Attr, "0">,
|
||||||
|
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
|
||||||
|
(returnType $dst))
|
||||||
|
>;
|
||||||
|
|
||||||
|
#endif // PIM_BUFFERIZATION
|
||||||
@@ -5,23 +5,43 @@
|
|||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "Common/PimCommon.hpp"
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "PimBufferizationPass.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using namespace onnx_mlir;
|
||||||
|
using namespace pim;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace pim {
|
namespace {
|
||||||
|
|
||||||
|
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||||
|
|
||||||
|
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||||
|
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||||
|
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
|
||||||
|
|
||||||
|
PimBufferizationPass() = default;
|
||||||
|
PimBufferizationPass(const PimBufferizationPass& pass) {}
|
||||||
|
|
||||||
|
void runOnOperation() final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void PimBufferizationPass::runOnOperation() {
|
void PimBufferizationPass::runOnOperation() {
|
||||||
auto moduleOp = getOperation();
|
auto moduleOp = getOperation();
|
||||||
|
|
||||||
// Do One-Shot-Bufferization
|
// One-Shot-Bufferization
|
||||||
bufferization::OneShotBufferizationOptions options;
|
bufferization::OneShotBufferizationOptions options;
|
||||||
options.allowUnknownOps = true;
|
options.allowUnknownOps = true;
|
||||||
bufferization::BufferizationState state;
|
bufferization::BufferizationState state;
|
||||||
@@ -30,7 +50,19 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove toTensor operations
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
ConversionTarget target(*ctx);
|
||||||
|
target.addLegalDialect<PimDialect>();
|
||||||
|
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateWithGenerated(patterns);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove toTensor operations: leave memrefs instead
|
||||||
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
|
||||||
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
|
||||||
toTensorOp.erase();
|
toTensorOp.erase();
|
||||||
@@ -57,23 +89,22 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "pim_buf");
|
dumpModule(moduleOp, "pim1_buff");
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
MLIRContext* ctx = funcOp.getContext();
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); })
|
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
||||||
&& !getGlobalOp->getUsers().empty();
|
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
||||||
if (isAlwaysWeight) {
|
if (isAlwaysWeight) {
|
||||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
markWeightAlways(getGlobalOp);
|
||||||
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
|
markWeightAlways(globalMemrefOp);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pim
|
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -1,15 +1,22 @@
|
|||||||
add_onnx_mlir_dialect(Spatial spat)
|
add_onnx_mlir_dialect(Spatial spat)
|
||||||
add_onnx_mlir_dialect_doc(spat Spatial.td)
|
add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||||
|
|
||||||
|
add_pim_library(SpatialOps
|
||||||
add_onnx_mlir_library(SpatialOps
|
|
||||||
SpatialOps.cpp
|
SpatialOps.cpp
|
||||||
Transforms/SpatialBufferizableOpInterface.cpp
|
Transforms/SpatialBufferizableOpInterface.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
|
OMONNXIncGen
|
||||||
OMSpatialIncGen
|
OMSpatialIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
MLIRBufferizationDialect
|
||||||
|
MLIRBufferizationTransforms
|
||||||
OMMlirDialects
|
OMMlirDialects
|
||||||
|
OMONNXOps
|
||||||
|
OMPimCompilerOptions
|
||||||
|
PimOps
|
||||||
)
|
)
|
||||||
@@ -28,8 +28,12 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
|||||||
let summary = "Virtual channel type";
|
let summary = "Virtual channel type";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Execution
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||||
let summary = "Compute operation, with constant weights already attached";
|
let summary = "Compute region with attached constant weights";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<SpatTensor>:$weights,
|
Variadic<SpatTensor>:$weights,
|
||||||
@@ -50,6 +54,8 @@ def SpatWeightedCompute: SpatOp<"compute", [SingleBlock, AttrSizedOperandSegment
|
|||||||
}
|
}
|
||||||
|
|
||||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||||
|
let summary = "Yield results from a compute region";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<SpatTensor>:$outputs
|
Variadic<SpatTensor>:$outputs
|
||||||
);
|
);
|
||||||
@@ -60,12 +66,14 @@ def SpatYieldOp: SpatOp<"yield", [Terminator]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Data movement operations
|
// Communication
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
||||||
|
let summary = "Create a new virtual channel";
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatChannelType:$new_channel
|
SpatChannelType:$channel
|
||||||
);
|
);
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
@@ -80,107 +88,73 @@ def SpatChannelNewOp: SpatOp<"channel_new", []> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||||
|
let summary = "Send a tensor through a channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel,
|
SpatChannelType:$channel,
|
||||||
SpatTensor: $data
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$data `to` $channel attr-dict `:` `(` type($data) `->` type($channel) `)`
|
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||||
|
let summary = "Receive a tensor from a channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel
|
SpatChannelType:$channel
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor: $data
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$channel attr-dict `:` `(` type($channel) `->` type($data) `)`
|
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
||||||
|
let summary = "Broadcast a tensor through a shared channel buffer";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel,
|
SpatChannelType:$channel,
|
||||||
SpatTensor: $data
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
||||||
|
let summary = "Receive a tensor from a shared channel buffer";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatChannelType:$channel
|
SpatChannelType:$channel
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor: $data
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
let assemblyFormat = [{
|
||||||
// Math operations
|
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def SpatConstantOp: SpatOp<"constant", []> {
|
|
||||||
let description = [{
|
|
||||||
"Constant value, should be used for weights and biases"
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
AnyAttr: $value,
|
|
||||||
BoolAttr: $shouldAllocate
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor: $out
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Math
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
||||||
let summary = "Vector-matrix-Multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Attr:$weightIndex,
|
I32Attr:$weightIndex,
|
||||||
SpatTensor:$vector
|
SpatTensor:$input
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
|
||||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
|
||||||
let hasVerifier = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatWeightedMVMOp: SpatOp<"Wmvm", []> {
|
|
||||||
let summary = "Matrix-vector multiplication within a WeightedCompute operation. The matrix is found in the weights of the WeightedCompute operation, indexed by the weightIndex attribute.";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr: $weightIndex,
|
|
||||||
SpatTensor:$vector
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
// TODO: Verifier that checks it is within a WeightedCompute operation,
|
|
||||||
// that the weightIndex is valid, and that the matrix is of the right size.
|
|
||||||
let hasVerifier = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def SpatVAddOp: SpatOp<"vadd", []> {
|
|
||||||
let summary = "Element-wise add between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor: $a,
|
|
||||||
SpatTensor: $b
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -190,65 +164,67 @@ def SpatVAddOp: SpatOp<"vadd", []> {
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
|
||||||
|
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
I32Attr:$weightIndex,
|
||||||
|
SpatTensor:$input
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||||
|
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$lhs,
|
||||||
|
SpatTensor:$rhs
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatVMulOp : SpatOp<"vmul", []> {
|
def SpatVMulOp : SpatOp<"vmul", []> {
|
||||||
let summary = "Element-wise multiplication between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor: $a,
|
SpatTensor:$lhs,
|
||||||
SpatTensor: $b
|
SpatTensor:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
//let hasVerifier = 1;
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatVDivOp: SpatOp<"vdiv", []> {
|
|
||||||
let summary = "Element-wise division between tensors a and b. Tensor b must have the same size of tensor b or be a 1x1";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor:$a,
|
|
||||||
SpatTensor:$b
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
//let hasVerifier = 1;
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$a `,` $b attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//TODO: remove
|
|
||||||
def SpatVSDivOp: SpatOp<"vsdiv", []> {
|
|
||||||
|
|
||||||
let summary = "Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor:$dividend,
|
|
||||||
SpatTensor:$divisor
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatSumOp : SpatOp<"sum", []> {
|
def SpatSumOp : SpatOp<"sum", []> {
|
||||||
let summary = "Sum all the elements in the input tensors into a single scalar wrapped in tensor for convenience";
|
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
@@ -257,9 +233,15 @@ def SpatSumOp: SpatOp<"sum", []> {
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||||
|
let summary = "Element-wise sigmoid activation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
@@ -267,9 +249,15 @@ def SpatSigmoidOp: SpatOp<"sigmoid", []> {
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatReluOp : SpatOp<"relu", []> {
|
def SpatReluOp : SpatOp<"relu", []> {
|
||||||
|
let summary = "Element-wise ReLU activation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
@@ -277,15 +265,18 @@ def SpatReluOp: SpatOp<"relu", []> {
|
|||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatVMaxOp : SpatOp<"vmax", []> {
|
def SpatVMaxOp : SpatOp<"vmax", []> {
|
||||||
|
let summary = "Element-wise max between two tensors";
|
||||||
let summary = "Element-wise max function";
|
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor: $a,
|
SpatTensor:$lhs,
|
||||||
SpatTensor: $b
|
SpatTensor:$rhs
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -293,62 +284,9 @@ def SpatVMaxOp: SpatOp<"vmax", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
|
||||||
|
|
||||||
def SpatApplyFiltersOp : SpatOp<"apply_filters", []> {
|
|
||||||
let summary = "Apply multiple crossbar weights to a convolutional input tile.";
|
|
||||||
let description = [{
|
|
||||||
Applies a variable number of crossbar weights to a single large image tensor tile,
|
|
||||||
producing a corresponding output tile. This essentially encapsulates a big for loop
|
|
||||||
over all pixels in the input tile, where each pixel is multiplied by all the weights
|
|
||||||
in the operation.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I64ArrayAttr: $weightIndices,
|
|
||||||
I64ArrayAttr: $xKernelPositions,
|
|
||||||
I64ArrayAttr: $yKernelPositions,
|
|
||||||
SpatTensor: $input
|
|
||||||
);
|
|
||||||
let results = (outs SpatTensor);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$input attr-dict `:` type($input) `->` type(results)
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Other operations
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def SpatImgConcatOp: SpatOp<"img_concat", []> {
|
|
||||||
|
|
||||||
let summary = "Concatenate pixel tiles into a single image";
|
|
||||||
|
|
||||||
let description = [{
|
|
||||||
Concatenate pixel tiles into a single image:
|
|
||||||
1. First, concatenate the pixel tiles along the "channel" axis (axis 1).
|
|
||||||
2. Next, concatenate the pixel tiles along the "width" axis (axis 2).
|
|
||||||
3. Finally, concatenate the pixel tiles along the "height" axis (axis 3).
|
|
||||||
|
|
||||||
The input tiles should be provided in a specific order:
|
|
||||||
start from the top left pixel,
|
|
||||||
then continue with the pixel on its right,
|
|
||||||
and once you finish the first row of pixels, go to the next row.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<SpatTensor>:$inputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::Value getInputTile(size_t x, size_t y, size_t tile);
|
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,8 +24,8 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
|||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getVector().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
|
|
||||||
/* Two possible accepted shapes:
|
/* Two possible accepted shapes:
|
||||||
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
|
|||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getVector().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
|
|
||||||
/* Accepted shape:
|
/* Accepted shape:
|
||||||
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
|
|||||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatImgConcatOp::verify() {
|
|
||||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
|
||||||
size_t img_w = GET_IMAGE_WIDTH(imgShape);
|
|
||||||
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
|
|
||||||
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
|
|
||||||
|
|
||||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
|
||||||
size_t channelTileRest = img_c % crossbarSize;
|
|
||||||
|
|
||||||
auto operands = getOperands();
|
|
||||||
|
|
||||||
// Check number of operands
|
|
||||||
if (img_w * img_h * channelTiles != operands.size())
|
|
||||||
return emitError("Number of operands does not match output image size");
|
|
||||||
|
|
||||||
// For each output pixel, check that the inputTiles have a correct shape
|
|
||||||
for (size_t x = 0; x < img_w; x++) {
|
|
||||||
for (size_t y = 0; y < img_h; y++) {
|
|
||||||
size_t channel_counts = 0;
|
|
||||||
for (size_t t = 0; t < channelTiles; t++) {
|
|
||||||
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
|
|
||||||
if (!inputShape)
|
|
||||||
return emitError("Invalid input type, must be ShapedType");
|
|
||||||
|
|
||||||
// N == W == H == 1
|
|
||||||
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
|
|
||||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
|
||||||
|
|
||||||
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
|
|
||||||
|
|
||||||
// Check the number of channels in this tile are correct:
|
|
||||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
|
||||||
// - CASE2: common case, the channel count is exactly the crossbarSize
|
|
||||||
if (t == channelTiles - 1 && channelTileRest != 0) {
|
|
||||||
if (inputChannels != channelTileRest)
|
|
||||||
return emitError("Invalid channel count for last tile of pixel");
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (inputChannels != crossbarSize)
|
|
||||||
return emitError("Invalid channel count for some pixel tile");
|
|
||||||
}
|
|
||||||
|
|
||||||
channel_counts += inputChannels;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (channel_counts != img_c)
|
|
||||||
emitError("Invalid number of channels for some pixel");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult SpatWeightedCompute::verify() {
|
LogicalResult SpatWeightedCompute::verify() {
|
||||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||||
// operand with the same type as the result
|
// operand with the same type as the result
|
||||||
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
|
||||||
auto operands = getOperands();
|
|
||||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
|
||||||
size_t img_w = GET_IMAGE_WIDTH(imgShape);
|
|
||||||
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
|
|
||||||
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
|
|
||||||
|
|
||||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
|
||||||
|
|
||||||
assert(tile < channelTiles);
|
|
||||||
assert(x < img_w);
|
|
||||||
assert(y < img_h);
|
|
||||||
|
|
||||||
return operands[tile + x * channelTiles + y * img_w * channelTiles];
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,8 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -34,11 +34,79 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
|||||||
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
|
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
|
||||||
|
|
||||||
// Alloc an output memref
|
// Alloc an output memref
|
||||||
return rewriter.create<memref::AllocOp>(loc, memrefResultType);
|
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
|
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||||
|
return memrefValue;
|
||||||
|
|
||||||
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
|
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
|
||||||
|
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
|
return pim::PimMemCopyOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
contiguousBuffer.getType(),
|
||||||
|
contiguousBuffer,
|
||||||
|
memrefValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||||
|
|
||||||
|
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
|
||||||
|
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
if (!channelNewOp) {
|
||||||
|
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto channelUsers = channelNewOp->getUsers();
|
||||||
|
auto usersIterator = channelUsers.begin();
|
||||||
|
auto firstUser = *usersIterator;
|
||||||
|
++usersIterator;
|
||||||
|
if (usersIterator == channelUsers.end()) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
|
||||||
|
channelNewOp->dump();
|
||||||
|
op->dump();
|
||||||
|
channelNewOp->getParentOp()->dump();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto secondUser = *usersIterator;
|
||||||
|
++usersIterator;
|
||||||
|
if (usersIterator != channelUsers.end()) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* otherUser = nullptr;
|
||||||
|
if (firstUser == op)
|
||||||
|
otherUser = secondUser;
|
||||||
|
else if (secondUser == op)
|
||||||
|
otherUser = firstUser;
|
||||||
|
else {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return otherUser;
|
||||||
|
}
|
||||||
|
|
||||||
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||||
|
|
||||||
// This function requires the existence of ChannelNewOp and the other
|
// This function requires the existence of ChannelNewOp and the other
|
||||||
@@ -49,7 +117,7 @@ llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsR
|
|||||||
if (precomputedOtherCoreId)
|
if (precomputedOtherCoreId)
|
||||||
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
|
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
|
||||||
|
|
||||||
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
|
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
|
||||||
if (failed(notOpUserOpt))
|
if (failed(notOpUserOpt))
|
||||||
return failure();
|
return failure();
|
||||||
Operation* notOpUser = *notOpUserOpt;
|
Operation* notOpUser = *notOpUserOpt;
|
||||||
@@ -119,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
|||||||
auto memref = getBuffer(rewriter, operand, options, state);
|
auto memref = getBuffer(rewriter, operand, options, state);
|
||||||
if (failed(memref))
|
if (failed(memref))
|
||||||
return failure();
|
return failure();
|
||||||
memrefOperands.push_back(*memref);
|
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Support addiction with more than 2 operands
|
// TODO: Support addiction with more than 2 operands
|
||||||
@@ -134,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
|||||||
|
|
||||||
memrefOperands.push_back(outputTensor);
|
memrefOperands.push_back(outputTensor);
|
||||||
|
|
||||||
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -169,11 +237,13 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
|||||||
// Alloc an output memref
|
// Alloc an output memref
|
||||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||||
|
|
||||||
Value newValue =
|
Value newValue = ToTy::create(rewriter,
|
||||||
rewriter
|
op->getLoc(),
|
||||||
.create<ToTy>(
|
outputTensor.getType(),
|
||||||
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
|
cast<OpTy>(op).getWeightIndexAttr(),
|
||||||
.getOutRes();
|
memrefOperand,
|
||||||
|
outputTensor)
|
||||||
|
.getOutput();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -213,13 +283,13 @@ struct ChannelReceiveOpInterface
|
|||||||
if (failed(srcCoreId))
|
if (failed(srcCoreId))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value newValue = rewriter
|
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||||
.create<pim::PimReceiveOp>(op->getLoc(),
|
op->getLoc(),
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||||
.getOut();
|
.getOutput();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -300,7 +370,8 @@ struct ChannelBroadcastReceiveOpInterface
|
|||||||
|
|
||||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||||
|
|
||||||
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
|
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||||
if (!channelNewOp) {
|
if (!channelNewOp) {
|
||||||
@@ -323,7 +394,8 @@ struct ChannelBroadcastReceiveOpInterface
|
|||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
|
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
|
op->getLoc(),
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
bufferAllocation,
|
bufferAllocation,
|
||||||
@@ -331,7 +403,7 @@ struct ChannelBroadcastReceiveOpInterface
|
|||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
rewriter.getI32IntegerAttr(outputSize));
|
rewriter.getI32IntegerAttr(outputSize));
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
|
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -356,7 +428,8 @@ struct ChannelBroadcastSendOpInterface
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Turn the channel send to pim.send
|
* Turn the channel send into a device-to-host copy into the shared
|
||||||
|
* broadcast buffer that receive ops load from later.
|
||||||
*/
|
*/
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
RewriterBase& rewriter,
|
RewriterBase& rewriter,
|
||||||
@@ -389,101 +462,31 @@ struct ChannelBroadcastSendOpInterface
|
|||||||
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||||
|
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
op->getLoc(),
|
||||||
|
bufferAllocation.getType(),
|
||||||
|
bufferAllocation,
|
||||||
|
srcMemRef,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||||
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct VAddOpInterfaceFromTemplate
|
struct VAddOpInterfaceFromTemplate
|
||||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
|
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
||||||
|
|
||||||
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||||
|
|
||||||
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
||||||
|
|
||||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||||
|
|
||||||
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
|
|
||||||
|
|
||||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
|
|
||||||
|
|
||||||
// Create a new bufferizable op interface for the apply filters operation.
|
|
||||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
|
||||||
|
|
||||||
// One operand ($input) is read from. All other inputs are only written to.
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
|
|
||||||
// Operand 0: $input
|
|
||||||
// Operand 1: $outBuf
|
|
||||||
// Operand 2: $accumBuf
|
|
||||||
return opOperand.getOperandNumber() == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// One input ($accumBuf) is written to. All other inputs are only read.
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
|
|
||||||
// Operand 0: $input
|
|
||||||
// Operand 1: $outBuf
|
|
||||||
// Operand 2: $accumBuf
|
|
||||||
return opOperand.getOperandNumber() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No operands are aliased with any other operands.
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bufferize the operation.
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
|
|
||||||
// Get the input tensor buffer.
|
|
||||||
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
|
|
||||||
|
|
||||||
if (failed(inputBuffer))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Create a new buffer for the output tensor.
|
|
||||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
// Create a new buffer for the accumulation buffer.
|
|
||||||
// To do this, create a new allocation operation. Size must be axbx1x1,
|
|
||||||
// where axbxcxd is the size of the output tensor. Since the shape is
|
|
||||||
// different, we can't immediately use createEmptyFromType, we first need to
|
|
||||||
// create the shape of the accumulation buffer.
|
|
||||||
auto accumShape = llvm::to_vector<4>(cast<ShapedType>(op->getResult(0).getType()).getShape());
|
|
||||||
|
|
||||||
// Set the last two dimensions to 1.
|
|
||||||
accumShape[accumShape.size() - 1] = 1;
|
|
||||||
accumShape[accumShape.size() - 2] = 1;
|
|
||||||
|
|
||||||
auto accumType = MemRefType::get(accumShape, cast<ShapedType>(op->getResult(0).getType()).getElementType());
|
|
||||||
|
|
||||||
auto accumBuffer = createEmptyFromType(accumType, op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
// Bufferize the operation.
|
|
||||||
auto weightIndices = cast<SpatApplyFiltersOp>(op).getWeightIndicesAttr();
|
|
||||||
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
|
||||||
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
|
||||||
|
|
||||||
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
weightIndices,
|
|
||||||
xKernelPositions,
|
|
||||||
yKernelPositions,
|
|
||||||
*inputBuffer,
|
|
||||||
outputTensor,
|
|
||||||
accumBuffer);
|
|
||||||
|
|
||||||
// Replace the operation with the bufferized value.
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, bufferized);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
|
||||||
@@ -491,25 +494,26 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
|||||||
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
|
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
|
||||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
|
||||||
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
|
|
||||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||||
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
|
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
|
||||||
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
|
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
|
||||||
SpatApplyFiltersOp::attachInterface<ApplyFiltersOpInterface>(*ctx);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||||
|
|
||||||
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
|
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
||||||
|
|
||||||
|
struct ONNXSigmoidInterface
|
||||||
|
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||||
|
|
||||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
||||||
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
||||||
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
|
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
|
||||||
|
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,14 +4,12 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||||
|
|
||||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
17
src/PIM/Pass/CMakeLists.txt
Normal file
17
src/PIM/Pass/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
add_pim_library(OMPimPasses
|
||||||
|
CountInstructionPass.cpp
|
||||||
|
MessagePass.cpp
|
||||||
|
Pim/ConstantFolding/Common.cpp
|
||||||
|
Pim/ConstantFolding/Patterns/Constant.cpp
|
||||||
|
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
||||||
|
Pim/ConstantFolding/Patterns/Subview.cpp
|
||||||
|
Pim/MaterializeConstantsPass.cpp
|
||||||
|
Pim/VerificationPass.cpp
|
||||||
|
Pim/EmitPimJsonPass.cpp
|
||||||
|
|
||||||
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRLinalgDialect
|
||||||
|
OMPimCommon
|
||||||
|
)
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
#include "src/Compiler/CompilerUtils.hpp"
|
||||||
|
|||||||
30
src/PIM/Pass/PIMPasses.h
Normal file
30
src/PIM/Pass/PIMPasses.h
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::Pass> createCountInstructionPass();
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
119
src/PIM/Pass/Pim/ConstantFolding/Common.cpp
Normal file
119
src/PIM/Pass/Pim/ConstantFolding/Common.cpp
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value stripMemRefCasts(Value value) {
|
||||||
|
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||||
|
value = castOp.getSource();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value stripMemRefViewOps(Value value) {
|
||||||
|
while (true) {
|
||||||
|
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||||
|
Location loc,
|
||||||
|
MemRefType globalType,
|
||||||
|
DenseElementsAttr denseAttr,
|
||||||
|
StringRef nameStem,
|
||||||
|
IntegerAttr alignment) {
|
||||||
|
auto globalName = nameStem.str();
|
||||||
|
unsigned suffix = 0;
|
||||||
|
while (moduleOp.lookupSymbol(globalName))
|
||||||
|
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||||
|
|
||||||
|
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||||
|
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||||
|
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||||
|
return memref::GlobalOp::create(moduleBuilder,
|
||||||
|
loc,
|
||||||
|
globalName,
|
||||||
|
visibility,
|
||||||
|
globalType,
|
||||||
|
denseAttr,
|
||||||
|
/*constant=*/true,
|
||||||
|
alignment);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||||
|
value = stripMemRefCasts(value);
|
||||||
|
|
||||||
|
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
return denseAttr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Describe `value` as a fully static memref.subview: resolves the defining
/// subview (through view-like wrappers), checks that both the source and the
/// view have static shapes, and extracts the constant offsets, sizes and
/// strides. Fails if any of these is dynamic.
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  Value stripped = stripMemRefViewOps(value);
  auto subviewOp = stripped.getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();

  // Both endpoints of the view must be static, ranked memrefs.
  Value source = stripMemRefCasts(subviewOp.getSource());
  auto sourceType = dyn_cast<MemRefType>(source.getType());
  auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
  if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
    return failure();

  StaticSubviewInfo info;
  info.source = source;
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  info.sourceShape.assign(sourceShape.begin(), sourceShape.end());

  // Append the constant value of every OpFoldResult to `out`; reports
  // false as soon as a non-constant entry is encountered.
  auto collectStatic = [](ArrayRef<OpFoldResult> values, SmallVectorImpl<int64_t>& out) {
    for (OpFoldResult entry : values) {
      auto constant = getConstantIntValue(entry);
      if (!constant)
        return false;
      out.push_back(*constant);
    }
    return true;
  };

  if (!collectStatic(subviewOp.getMixedOffsets(), info.offsets)
      || !collectStatic(subviewOp.getMixedSizes(), info.sizes)
      || !collectStatic(subviewOp.getMixedStrides(), info.strides))
    return failure();

  return info;
}
|
||||||
|
|
||||||
|
/// Compute the byte offset, within the subview's source buffer, of the
/// innermost-dimension "chunk" selected by `outerIndices`.
///
/// For every dimension except the last, the source index is
/// `offset[d] + outerIndices[d] * stride[d]`; for the last dimension only
/// the static offset is used (the chunk starts at the subview's offset and
/// spans the innermost dimension). The resulting multi-index is linearized
/// with row-major strides of the full source shape and scaled by
/// `elementByteWidth`.
///
/// NOTE(review): assumes `outerIndices` supplies at least rank-1 entries and
/// that `info.offsets`/`info.strides` match the source rank; also assumes a
/// non-empty shape (`info.offsets.back()` is UB on rank 0) — callers appear
/// to guarantee this via getStaticSubviewInfo, confirm.
int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
  SmallVector<int64_t> sourceIndices;
  sourceIndices.reserve(info.sourceShape.size());
  // All dimensions but the innermost are positioned by the caller's indices.
  for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
    sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
  // Innermost dimension: the chunk begins at the subview's static offset.
  sourceIndices.push_back(info.offsets.back());
  return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
41
src/PIM/Pass/Pim/ConstantFolding/Common.hpp
Normal file
41
src/PIM/Pass/Pim/ConstantFolding/Common.hpp
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Fully static description of a memref.subview, extracted by
/// getStaticSubviewInfo(). All entries are per-dimension and constant.
struct StaticSubviewInfo {
  // Source memref of the subview, with memref.cast wrappers stripped.
  mlir::Value source;
  // Static shape of the source memref.
  llvm::SmallVector<int64_t> sourceShape;
  // Constant subview offsets, one per dimension.
  llvm::SmallVector<int64_t> offsets;
  // Constant subview sizes, one per dimension.
  llvm::SmallVector<int64_t> sizes;
  // Constant subview strides, one per dimension.
  llvm::SmallVector<int64_t> strides;
};
|
||||||
|
|
||||||
|
/// Look through memref.cast ops and return the underlying value.
mlir::Value stripMemRefCasts(mlir::Value value);

/// Look through view-like memref ops and return the underlying value.
mlir::Value stripMemRefViewOps(mlir::Value value);

/// Create (at module scope) a private constant memref.global holding
/// `denseAttr`, named from `nameStem` (presumably uniqued against existing
/// symbols — see the definition). `alignment` is optional.
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
    mlir::Location loc,
    mlir::MemRefType globalType,
    mlir::DenseElementsAttr denseAttr,
    llvm::StringRef nameStem,
    mlir::IntegerAttr alignment = {});

/// Resolve `value` (through casts) to the dense initializer of the constant
/// memref.global backing it; fails if there is no such global.
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);

/// Extract a fully static description of the subview producing `value`;
/// fails on dynamic shapes/offsets/sizes/strides.
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);

/// Byte offset of the innermost-dimension chunk addressed by `outerIndices`
/// within the subview's source buffer (row-major layout).
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
    llvm::ArrayRef<int64_t> outerIndices,
    int64_t elementByteWidth);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
52
src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp
Normal file
52
src/PIM/Pass/Pim/ConstantFolding/ConstantFoldingPass.cpp
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Module pass that folds host-side constant expressions (constant fills,
/// transposes of constant globals, copies from constant globals) into
/// memref.globals before PIM verification runs.
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)

  StringRef getArgument() const override { return "pim-constant-folding-pass"; }
  StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }

  /// Build the frozen pattern set once per pass instance: all loaded
  /// dialects' and registered ops' canonicalization patterns, plus the
  /// PIM-specific constant-folding patterns.
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet owningPatterns(context);
    for (auto* dialect : context->getLoadedDialects())
      dialect->getCanonicalizationPatterns(owningPatterns);
    for (RegisteredOperationName op : context->getRegisteredOperations())
      op.getCanonicalizationPatterns(owningPatterns, context);

    populateConstantFoldingConstantPatterns(owningPatterns);
    populateConstantFoldingSubviewPatterns(owningPatterns);

    // shared_ptr so the frozen set can outlive per-run copies of the pass.
    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
    return success();
  }

  /// Apply the patterns greedily (with op folding enabled) and dump the
  /// folded module for debugging.
  void runOnOperation() override {
    GreedyRewriteConfig config;
    config.enableFolding();
    if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
      signalPassFailure();
      return;
    }

    dumpModule(getOperation(), "pim2_folded");
  }

  // Frozen pattern set built in initialize(), shared across runs.
  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Factory for the PIM constant-folding pass (see ConstantFoldingPass above).
std::unique_ptr<Pass> createPimConstantFoldingPass() {
  return std::make_unique<ConstantFoldingPass>();
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
11
src/PIM/Pass/Pim/ConstantFolding/Patterns.hpp
Normal file
11
src/PIM/Pass/Pim/ConstantFolding/Patterns.hpp
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Register the constant-folding patterns (constant fills, transposes,
/// copies of constant globals) into `patterns`.
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);

/// Register the subview-related constant-folding patterns into `patterns`.
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
485
src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp
Normal file
485
src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
|
||||||
|
#include "../Common.hpp"
|
||||||
|
#include "../Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// One `memref.copy` of a constant global into a static subview of an
/// allocation, recorded while scanning an alloc's users in
/// foldConstantAlloc().
struct ConstantSubviewCopy {
  // Dense initializer of the copied-from constant global.
  DenseElementsAttr source;
  // Constant per-dimension offsets of the target subview.
  SmallVector<int64_t> offsets;
  // Constant per-dimension strides of the target subview.
  SmallVector<int64_t> strides;
  // The memref.copy op itself; used to replay copies in program order.
  Operation* copyOp = nullptr;
};
|
||||||
|
|
||||||
|
/// Transpose the values of `denseAttr` according to `perms`, where output
/// dimension d takes its extent (and indices) from input dimension perms[d].
/// Fails when the attribute is not a ranked tensor or `perms` is not a valid
/// permutation of [0, rank). Splats are returned directly (transposition is
/// a no-op on them) without materializing per-element values.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!tensorType)
    return failure();

  int64_t rank = tensorType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  // Validate the permutation while deriving the transposed shape.
  llvm::SmallBitVector seen(rank);
  SmallVector<int64_t> transposedShape;
  transposedShape.reserve(rank);
  for (int64_t perm : perms) {
    if (perm < 0 || perm >= rank || seen.test(perm))
      return failure();
    seen.set(perm);
    transposedShape.push_back(tensorType.getShape()[perm]);
  }

  auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());

  // Scatter each source element to its transposed position. Use the shared
  // row-major index helpers (computeRowMajorStrides / delinearizeIndex /
  // linearizeIndex) instead of hand-rolled stride arithmetic, for
  // consistency with foldConstantAlloc and the memcpy folding below.
  auto originalStrides = computeRowMajorStrides(tensorType.getShape());
  auto transposedStrides = computeRowMajorStrides(transposedShape);

  SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> transposedValues(originalValues.size());
  SmallVector<int64_t> transposedIndices(rank);
  for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
    SmallVector<int64_t> originalIndices =
        delinearizeIndex(static_cast<int64_t>(linearIndex), tensorType.getShape(), originalStrides);

    // Output index along dim d is the input index along dim perms[d].
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedIndices[dim] = originalIndices[perms[dim]];

    transposedValues[linearizeIndex(transposedIndices, transposedStrides)] = value;
  }

  return DenseElementsAttr::get(transposedType, transposedValues);
}
|
||||||
|
|
||||||
|
/// If `mapOp` is a pure constant fill — no inputs, and its mapper region
/// yields a single compile-time constant — return that constant attribute;
/// otherwise fail.
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
  // Any input would make the yielded value data-dependent.
  if (!mapOp.getInputs().empty())
    return failure();

  // The mapper body must terminate in a single-operand linalg.yield.
  auto yieldOp = dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
  if (!yieldOp)
    return failure();
  if (yieldOp.getNumOperands() != 1)
    return failure();

  // The yielded value must match a constant producer.
  Attribute yielded;
  if (matchPattern(yieldOp.getValues().front(), m_Constant(&yielded)))
    return yielded;
  return failure();
}
|
||||||
|
|
||||||
|
/// Replace a constant-fill linalg.map that lives inside a pim.core with a
/// get_global of a precomputed splat constant: the global is created at
/// module scope, the get_global is inserted *before* the core op, and all
/// other users of the map's init buffer are redirected to it.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
    // Only maps nested inside a pim.core are handled here.
    auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
    if (!coreOp)
      return failure();

    // The init buffer must have a static shape to materialize a global.
    auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
    if (!initType || !initType.hasStaticShape())
      return failure();

    // The map must be a pure constant fill.
    auto fillValue = getConstantMapYield(mapOp);
    if (failed(fillValue))
      return failure();

    // Materialize the fill as a splat tensor attribute.
    auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
    DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);

    auto moduleOp = mapOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

    // Insert the get_global outside (before) the core region so the folded
    // constant is visible to the core as an operand-like value.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(coreOp);
    auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());

    // Redirect every other user of the init buffer to the constant, then
    // drop the now-dead fill.
    rewriter.setInsertionPoint(mapOp);
    rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
    rewriter.eraseOp(mapOp);
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Try to compute, as a DenseElementsAttr, the full constant contents of
/// `allocOp`'s buffer. Succeeds only when every (transitive, cast-aliased)
/// user is one of: a constant-fill linalg.map over the whole buffer, a
/// static memref.subview whose only users are memref.copy ops from constant
/// globals, a pim.core reader, or a dealloc. The result is the fill value
/// with the subview copies replayed over it in program order.
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
    return failure();

  auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
  const int64_t numElements = resultTensorType.getNumElements();
  // Guard against overflow in the element count.
  if (numElements < 0)
    return failure();

  // Worklist walk over the alloc and its memref.cast aliases. Any user we
  // cannot classify makes the whole fold fail.
  Attribute fillValue;
  SmallVector<ConstantSubviewCopy> copies;
  llvm::SmallPtrSet<Operation*, 8> visitedAliases;
  SmallVector<Value> pendingAliases;
  pendingAliases.push_back(allocOp.getResult());

  while (!pendingAliases.empty()) {
    Value alias = pendingAliases.pop_back_val();
    for (Operation* user : alias.getUsers()) {
      // Each op is classified once, even if it uses several aliases.
      if (!visitedAliases.insert(user).second)
        continue;

      // Case 1: constant fill of the whole buffer. All fills must agree on
      // the same value.
      if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
        if (mapOp.getInit() != alias)
          return failure();
        auto maybeFillValue = getConstantMapYield(mapOp);
        if (failed(maybeFillValue))
          return failure();
        if (fillValue && fillValue != *maybeFillValue)
          return failure();
        fillValue = *maybeFillValue;
        continue;
      }

      // Case 2: a static subview whose only users copy a constant global
      // into it. Record the copy for later replay.
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
        SmallVector<int64_t> offsets;
        SmallVector<int64_t> strides;
        offsets.reserve(subviewOp.getMixedOffsets().size());
        strides.reserve(subviewOp.getMixedStrides().size());
        for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
          auto staticOffset = getConstantIntValue(offset);
          if (!staticOffset)
            return failure();
          offsets.push_back(*staticOffset);
        }
        for (OpFoldResult stride : subviewOp.getMixedStrides()) {
          auto staticStride = getConstantIntValue(stride);
          if (!staticStride)
            return failure();
          strides.push_back(*staticStride);
        }

        for (Operation* subviewUser : subviewOp->getUsers()) {
          if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
            // The subview must be the copy target, not its source.
            if (copyOp.getTarget() != subviewOp.getResult())
              return failure();

            auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
            if (failed(denseAttr))
              return failure();
            copies.push_back({*denseAttr, offsets, strides, copyOp});
            continue;
          }
          return failure();
        }
        continue;
      }

      // Case 3: readers (pim.core) and dealloc do not affect the contents.
      if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
        continue;

      // Case 4: casts alias the buffer; follow their users too.
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        pendingAliases.push_back(castOp.getResult());
        continue;
      }

      // Anything else may write the buffer in ways we cannot model.
      return failure();
    }
  }

  // Without a fill the uncopied elements would be undefined.
  if (!fillValue)
    return failure();

  SmallVector<Attribute> resultValues(numElements, fillValue);
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());

  // Replay copies in program order so later copies win on overlap, matching
  // runtime semantics (assumes all copies are in one block — isBeforeInBlock
  // requires that).
  llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
    return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
  });

  for (const ConstantSubviewCopy& copy : copies) {
    auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    // Source rank must match the recorded subview offsets/strides.
    if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
        || sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
      return failure();

    // Scatter each source element to offset + index * stride in the result.
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
    for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
      SmallVector<int64_t> sourceIndices =
          delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
      SmallVector<int64_t> resultIndices;
      resultIndices.reserve(sourceIndices.size());
      for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
        resultIndices.push_back(offset + sourceIndex * stride);

      int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
      resultValues[resultLinearIndex] = value;
    }
  }

  return DenseElementsAttr::get(resultTensorType, resultValues);
}
|
||||||
|
|
||||||
|
/// Fold pim.transpose of a constant global: compute the transposed values at
/// compile time, emit them as a new constant global, and replace the
/// transpose with a get_global of it. When every user of the transpose is a
/// pim.core, the new global/get_global are additionally marked as
/// always-weight.
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
    // The output must be a static memref to materialize a global.
    auto resultType = dyn_cast<MemRefType>(transposeOp.getOutput().getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();

    // The input must come directly from a get_global (no cast-stripping
    // here, unlike getDenseGlobalValue).
    auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
    if (!sourceGetGlobal)
      return failure();

    auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // The referenced global must be a constant with a dense initializer.
    auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
    if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
      return failure();

    auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
    if (!denseAttr)
      return failure();

    // Extract the permutation and transpose the values at compile time.
    SmallVector<int64_t> perms;
    perms.reserve(transposeOp.getPermutation().size());
    for (IntegerAttr attr : transposeOp.getPermutation().getAsRange<IntegerAttr>())
      perms.push_back(attr.getInt());
    FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
    if (failed(transposedAttr))
      return failure();

    // Sanity check: the computed shape must match the op's declared result.
    auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
    if (!llvm::equal(transposedShape, resultType.getShape()))
      return failure();

    // Keep the source global's alignment on the folded global.
    auto newGlobal = createFoldedGlobal(moduleOp,
        transposeOp.getLoc(),
        resultType,
        *transposedAttr,
        sourceGlobal.getName().str() + "__folded_transpose",
        sourceGlobal.getAlignmentAttr());

    rewriter.setInsertionPoint(transposeOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());

    // Only mark as weight when the constant feeds cores exclusively.
    bool isAlwaysWeight =
        !transposeOp->getUsers().empty()
        && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
    if (isAlwaysWeight) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }

    rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Replace an alloc whose contents foldConstantAlloc() proved constant with
/// a get_global of a new constant global. Writer ops (fills, subviews and
/// their copies, deallocs) are erased; casts are rebuilt on top of the
/// get_global; pim.core readers are redirected to the constant.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
    auto moduleOp = allocOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // Compute the full constant contents; bails on any unmodelled user.
    auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
    if (failed(foldedAttr))
      return failure();

    auto allocType = cast<MemRefType>(allocOp.getType());
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());

    // Classify users: writer/cleanup ops get erased, casts get rebuilt,
    // pim.core readers get redirected; anything else aborts the rewrite.
    SmallVector<Operation*> opsToErase;
    SmallVector<memref::CastOp> castsToReplace;
    // NOTE(review): starts true and is only cleared by the cast check below;
    // direct subview/map users do not clear it — confirm this matches the
    // intended "live users" (post-rewrite survivors) semantics.
    bool allLiveUsersAreCoreOps = true;
    for (Operation* user : llvm::make_early_inc_range(allocOp->getUsers())) {
      if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
        opsToErase.push_back(user);
        continue;
      }

      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        castsToReplace.push_back(castOp);
        continue;
      }

      if (!isa<pim::PimCoreOp>(user))
        return failure();
    }

    // A cast feeding anything other than a pim.core means the constant is
    // not exclusively consumed as a core weight.
    if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
          return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
        })) {
      allLiveUsersAreCoreOps = false;
    }

    // Re-validate the user set before mutating anything (note: createFoldedGlobal
    // and the get_global above have already been emitted at this point).
    if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
          return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
        })) {
      return failure();
    }

    if (allLiveUsersAreCoreOps) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }

    // Redirect all remaining (core) uses to the constant, preserving the
    // ops we will erase or rebuild below so they are not rewritten first.
    llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
    for (memref::CastOp castOp : castsToReplace)
      preservedUsers.insert(castOp);
    rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);

    // Rebuild each cast on top of the new get_global so its users keep
    // their expected (possibly rank-erased) types.
    for (memref::CastOp castOp : castsToReplace) {
      rewriter.setInsertionPoint(castOp);
      Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
      rewriter.replaceOp(castOp, replacementCast);
      if (allLiveUsersAreCoreOps)
        markWeightAlways(replacementCast.getDefiningOp());
    }

    // Erase the writer ops; a subview's users (the copies) must go first.
    for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
        for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
          rewriter.eraseOp(subviewUser);
      if (op->use_empty())
        rewriter.eraseOp(op);
    }

    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Fold a host-side pim.mem_copy whose source is (a static subview of) a
/// constant global and whose target is a freshly allocated static buffer:
/// the copied slice is materialized as a new constant global and the alloc
/// is replaced by a get_global of it.
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only host-side copies; copies inside a core are left alone.
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    // Target must be a direct static alloc.
    auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();

    // Partial copies (non-zero element offsets) are not folded.
    if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
      return failure();

    // The source is either a static subview of a global or the global
    // itself (possibly behind casts).
    auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
    Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());

    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
    if (failed(denseAttr))
      return failure();

    DenseElementsAttr foldedAttr;
    if (succeeded(srcSubview)) {
      // Subview case: gather the slice element-by-element.
      auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return failure();
      // Only unit strides are supported: result index maps to
      // source index = offset + index.
      if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
        return failure();

      auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
      const int64_t numResultElements = resultTensorType.getNumElements();
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
      SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
      SmallVector<Attribute> resultValues(numResultElements);

      for (int64_t i = 0; i < numResultElements; ++i) {
        auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
        SmallVector<int64_t> sourceIndices;
        sourceIndices.reserve(resultIndices.size());
        // zip_equal also enforces that source rank == result rank here.
        for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
          sourceIndices.push_back(off + idx);
        int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
        resultValues[i] = sourceValues[srcLinear];
      }
      foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    }
    else {
      // Whole-global case: reuse the initializer verbatim, but only when
      // shapes/element types match exactly.
      auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
      if (resultTensorType != denseAttr->getType())
        return failure();
      foldedAttr = *denseAttr;
    }

    // The alloc's other users determine whether the constant is a pure core
    // weight; any unexpected user aborts the fold.
    bool allLiveUsersAreCores = true;
    for (Operation* user : allocOp->getUsers()) {
      if (user == copyOp)
        continue;
      if (isa<memref::DeallocOp>(user))
        continue;
      if (isa<pim::PimCoreOp>(user))
        continue;
      if (isa<memref::SubViewOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);

    // Redirect all uses (including the dealloc) to the constant, drop the
    // copy, and erase the alloc once it is dead.
    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Register the constant-folding rewrite patterns defined in this file.
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<FoldConstantTransposePattern>(context);
  patterns.add<FoldConstantAllocPattern>(context);
  patterns.add<FoldConstantCoreMapPattern>(context);
  patterns.add<FoldConstantMemCpPattern>(context);
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user