Compare commits
6 Commits
2676f2c7ef
...
a1b29dffe0
| Author | SHA1 | Date | |
|---|---|---|---|
| a1b29dffe0 | |||
| 661170a9aa | |||
| 461bdd808d | |||
| 50c545539b | |||
| 11916a2595 | |||
| 670d6ce94f |
@@ -0,0 +1,54 @@
|
||||
name: Prepare MLIR Cache
|
||||
description: Restore or build the cached MLIR toolchain.
|
||||
|
||||
inputs:
|
||||
llvm-commit:
|
||||
description: LLVM commit to build.
|
||||
required: true
|
||||
mold-linker-flags:
|
||||
description: Linker flags used to force mold.
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Restore MLIR cache
|
||||
id: restore-mlir-cache
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: onnx-mlir/llvm-project
|
||||
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
|
||||
|
||||
- name: Clone LLVM
|
||||
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
|
||||
cd onnx-mlir/llvm-project
|
||||
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
- name: Build MLIR
|
||||
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
cmake -S onnx-mlir/llvm-project/llvm -B onnx-mlir/llvm-project/build -G Ninja \
|
||||
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
|
||||
-DLLVM_ENABLE_RUNTIMES="openmp" \
|
||||
-DLLVM_TARGETS_TO_BUILD="host" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_ENABLE_RTTI=ON \
|
||||
-DENABLE_LIBOMPTARGET=OFF \
|
||||
-DLLVM_ENABLE_LIBEDIT=OFF \
|
||||
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
|
||||
cmake --build onnx-mlir/llvm-project/build
|
||||
|
||||
- name: Save MLIR cache
|
||||
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: onnx-mlir/llvm-project
|
||||
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
|
||||
@@ -0,0 +1,45 @@
|
||||
name: Prepare Protobuf Cache
|
||||
description: Restore or build the cached protobuf installation.
|
||||
|
||||
inputs:
|
||||
protobuf-commit:
|
||||
description: Protobuf tag or commit to build.
|
||||
required: true
|
||||
mold-linker-flags:
|
||||
description: Linker flags used to force mold.
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Restore protobuf cache
|
||||
id: restore-protobuf-cache
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/usr/local/lib/libproto*
|
||||
/usr/local/include/google/protobuf
|
||||
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-commit }}
|
||||
|
||||
- name: Install protobuf
|
||||
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
git clone --depth 1 --branch ${{ inputs.protobuf-commit }} https://github.com/protocolbuffers/protobuf
|
||||
cmake -S protobuf -B protobuf/build -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
|
||||
cmake --build protobuf/build
|
||||
sudo cmake --install protobuf/build
|
||||
rm -rf protobuf
|
||||
|
||||
- name: Save protobuf cache
|
||||
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/usr/local/lib/libproto*
|
||||
/usr/local/include/google/protobuf
|
||||
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-commit }}
|
||||
@@ -0,0 +1,26 @@
|
||||
name: Restore Raptor Build Cache
|
||||
description: Restore the cached raptor build directory for incremental builds.
|
||||
|
||||
inputs:
|
||||
key:
|
||||
description: Exact cache key to restore.
|
||||
required: true
|
||||
restore-keys:
|
||||
description: Prefixes used to restore the most recent compatible cache.
|
||||
required: false
|
||||
|
||||
outputs:
|
||||
cache-hit:
|
||||
description: Whether the exact cache key was restored.
|
||||
value: ${{ steps.restore-raptor-build-cache.outputs.cache-hit }}
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Restore raptor build cache
|
||||
id: restore-raptor-build-cache
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: build
|
||||
key: ${{ inputs.key }}
|
||||
restore-keys: ${{ inputs.restore-keys }}
|
||||
@@ -0,0 +1,16 @@
|
||||
name: Save Raptor Build Cache
|
||||
description: Save the raptor build directory after a successful incremental build.
|
||||
|
||||
inputs:
|
||||
key:
|
||||
description: Cache key used to save the build directory.
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Save raptor build cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: build
|
||||
key: ${{ inputs.key }}
|
||||
@@ -1,50 +0,0 @@
|
||||
name: Build MLIR Cache
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
llvm-commit:
|
||||
required: true
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
build-mlir:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Cache MLIR build
|
||||
id: cache-mlir
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: onnx-mlir/llvm-project
|
||||
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
|
||||
|
||||
- name: Install build dependencies
|
||||
if: steps.cache-mlir.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y cmake ninja-build build-essential
|
||||
|
||||
- name: Clone LLVM
|
||||
if: steps.cache-mlir.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
|
||||
cd onnx-mlir/llvm-project
|
||||
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
- name: Build MLIR
|
||||
if: steps.cache-mlir.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
mkdir -p onnx-mlir/llvm-project/build
|
||||
cd onnx-mlir/llvm-project/build
|
||||
cmake -G Ninja ../llvm \
|
||||
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
|
||||
-DLLVM_ENABLE_RUNTIMES="openmp" \
|
||||
-DLLVM_TARGETS_TO_BUILD="host" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_ENABLE_RTTI=ON \
|
||||
-DENABLE_LIBOMPTARGET=OFF \
|
||||
-DLLVM_ENABLE_LIBEDIT=OFF
|
||||
cmake --build .
|
||||
@@ -4,29 +4,14 @@ env:
|
||||
LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286
|
||||
PROTOBUF_COMMIT: v34.0
|
||||
CMAKE_VERSION: 4.3.0
|
||||
MOLD_LINKER_FLAGS: -fuse-ld=mold
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
# Expose env vars as outputs so they can be passed to reusable workflows
|
||||
config:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
llvm-commit: ${{ steps.set-vars.outputs.llvm-commit }}
|
||||
steps:
|
||||
- id: set-vars
|
||||
run: echo "llvm-commit=$LLVM_COMMIT" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-mlir-cache:
|
||||
needs: config
|
||||
uses: ./.github/workflows/build_mlir_cache.yml
|
||||
with:
|
||||
llvm-commit: ${{ needs.config.outputs.llvm-commit }}
|
||||
|
||||
validate:
|
||||
needs: build-mlir-cache
|
||||
validate-operations:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@@ -39,7 +24,7 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y ninja-build build-essential curl ca-certificates
|
||||
sudo apt install -y cmake ninja-build build-essential mold curl ca-certificates
|
||||
|
||||
- name: Install CMake
|
||||
run: |
|
||||
@@ -58,27 +43,17 @@ jobs:
|
||||
cmake --version
|
||||
which cmake
|
||||
|
||||
- name: Cache protobuf build
|
||||
id: cache-protobuf
|
||||
uses: actions/cache@v4
|
||||
- name: Prepare MLIR cache
|
||||
uses: ./.github/actions/prepare-mlir-cache
|
||||
with:
|
||||
path: |
|
||||
/usr/local/lib/libproto*
|
||||
/usr/local/include/google/protobuf
|
||||
key: protobuf-${{ runner.os }}-${{ env.PROTOBUF_COMMIT }}
|
||||
llvm-commit: ${{ env.LLVM_COMMIT }}
|
||||
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
|
||||
|
||||
- name: Install protobuf
|
||||
if: steps.cache-protobuf.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
git clone --depth 1 --branch ${{ env.PROTOBUF_COMMIT }} https://github.com/protocolbuffers/protobuf
|
||||
cd protobuf
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
|
||||
ninja
|
||||
sudo ninja install
|
||||
cd ../..
|
||||
rm -rf protobuf
|
||||
- name: Prepare protobuf cache
|
||||
uses: ./.github/actions/prepare-protobuf-cache
|
||||
with:
|
||||
protobuf-commit: ${{ env.PROTOBUF_COMMIT }}
|
||||
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
|
||||
|
||||
- name: Register installed libraries
|
||||
run: sudo ldconfig
|
||||
@@ -94,23 +69,34 @@ jobs:
|
||||
- name: Install Python dependencies
|
||||
run: pip install numpy onnx colorama
|
||||
|
||||
- name: Restore MLIR cache
|
||||
uses: actions/cache/restore@v4
|
||||
- name: Restore raptor build cache
|
||||
id: restore-raptor-build-cache
|
||||
uses: ./.github/actions/restore-raptor-build-cache
|
||||
with:
|
||||
path: onnx-mlir/llvm-project
|
||||
key: mlir-${{ runner.os }}-${{ env.LLVM_COMMIT }}
|
||||
fail-on-cache-miss: true
|
||||
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-
|
||||
raptor-build-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-
|
||||
|
||||
- name: Build Raptor
|
||||
id: build-raptor
|
||||
run: |
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
|
||||
mkdir -p build && cd build
|
||||
cmake .. -G Ninja \
|
||||
cmake -S . -B build -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR}
|
||||
cmake --build .
|
||||
-DMLIR_DIR=${MLIR_DIR} \
|
||||
-DCMAKE_EXE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}"
|
||||
cmake --build build
|
||||
|
||||
- name: Save raptor build cache
|
||||
if: steps.build-raptor.outcome == 'success' && steps.restore-raptor-build-cache.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/save-raptor-build-cache
|
||||
with:
|
||||
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
|
||||
|
||||
- name: Run validation
|
||||
run: |
|
||||
|
||||
@@ -29,7 +29,7 @@ Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
|
||||
|
||||
Moreover, if compiling with build type debug, it is also suggested to use
|
||||
mold as linker (you will need to install it if you don't have it already)
|
||||
to reduce memory usage during linking. You can use it with:
|
||||
to reduce memory usage during linking. You can use it by setting the options:
|
||||
```
|
||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
|
||||
@@ -38,8 +38,16 @@ to reduce memory usage during linking. You can use it with:
|
||||
### Raptor
|
||||
|
||||
Use the following commands to build Raptor.
|
||||
|
||||
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
|
||||
|
||||
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
|
||||
setting the options:
|
||||
```
|
||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
|
||||
```
|
||||
|
||||
```
|
||||
git submodule update --init --recursive
|
||||
|
||||
|
||||
@@ -530,6 +530,7 @@ where
|
||||
let r2_val = r2;
|
||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||
let rd_val = core.register(rd);
|
||||
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||
let load1 = loads[0];
|
||||
|
||||
@@ -224,7 +224,21 @@ fn json_to_vvsub(
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
todo!("Not present in the compiler");
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vvsub", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let rs2 = json_i64!(json, "rs2") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_r2(rs2)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -256,7 +270,21 @@ fn json_to_vvdmul(
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
todo!("Not present in the compiler");
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vvdmul", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let rs2 = json_i64!(json, "rs2") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_r2(rs2)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -306,7 +334,21 @@ fn json_to_vavg(
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
todo!("Not present in the compiler");
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vavg", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let rs2 = json_i64!(json, "rs2") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_r1(rs1)
|
||||
.set_r2(rs2)
|
||||
.set_imm_len(len)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -358,7 +400,7 @@ fn json_to_vsigm(
|
||||
json: &Value,
|
||||
) -> Result<()> {
|
||||
let json = json.as_object().expect("Not an object");
|
||||
assert_eq!("vsigmoid", json_str!(json, "op"));
|
||||
assert_eq!("vsigm", json_str!(json, "op"));
|
||||
let rd = json_i64!(json, "rd") as i32;
|
||||
let rs1 = json_i64!(json, "rs1") as i32;
|
||||
let len = json_i64!(json, "len") as i32;
|
||||
|
||||
+38
-17
@@ -10,34 +10,54 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||
|
||||
add_subdirectory(Common)
|
||||
add_subdirectory(Compiler)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
|
||||
add_onnx_mlir_library(OMPIMAccel
|
||||
PimAccelerator.cpp
|
||||
Pass/CountInstructionPass.cpp
|
||||
Pass/EmitPimJsonPass.cpp
|
||||
Pass/MessagePass.cpp
|
||||
Pass/PimConstantFoldingPass.cpp
|
||||
Pass/PimHostVerificationPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PUBLIC
|
||||
set(PIM_PUBLIC_INCLUDE_DIRS
|
||||
${ONNX_MLIR_SRC_ROOT}/include
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_SRC_ROOT}
|
||||
${PIM_BIN_ROOT}
|
||||
${PIM_INCLUDE_PATH}
|
||||
)
|
||||
|
||||
set(PIM_COMPILER_INCLUDE_DIRS
|
||||
${PIM_SRC_ROOT}
|
||||
${PIM_BIN_ROOT}
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||
)
|
||||
|
||||
set(PIM_ACCEL_INCLUDE_DIRS
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||
)
|
||||
|
||||
set(PIM_GENERATED_INCLUDE_DIRS
|
||||
${PIM_INCLUDE_PATH}
|
||||
)
|
||||
|
||||
function(add_pim_library name)
|
||||
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Common)
|
||||
add_subdirectory(Pass)
|
||||
add_subdirectory(Compiler)
|
||||
add_subdirectory(Conversion)
|
||||
|
||||
add_pim_library(OMPIMAccel
|
||||
PimAccelerator.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PUBLIC
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
onnx
|
||||
OMAccelerator
|
||||
OMPimCompilerUtils
|
||||
OMCompilerUtils
|
||||
OMPimPasses
|
||||
OMONNXOps
|
||||
SpatialOps
|
||||
PimOps
|
||||
@@ -45,5 +65,6 @@ add_onnx_mlir_library(OMPIMAccel
|
||||
OMSpatialToGraphviz
|
||||
OMSpatialToPim
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
MLIRTensorInferTypeOpInterfaceImpl
|
||||
)
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
add_onnx_mlir_library(OMPimCommon
|
||||
add_pim_library(OMPimCommon
|
||||
PimCommon.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PUBLIC
|
||||
${ONNX_MLIR_SRC_ROOT}/include
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_SRC_ROOT}
|
||||
${PIM_BIN_ROOT}
|
||||
${PIM_INCLUDE_PATH}
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
onnx
|
||||
OMPimCompilerUtils
|
||||
SpatialOps
|
||||
PimOps
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
@@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
int64_t byteOffset = 0;
|
||||
|
||||
while (true) {
|
||||
if (isa<BlockArgument>(value))
|
||||
return ResolvedContiguousAddress{value, byteOffset};
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return failure();
|
||||
value = tiedOperand->get();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
|
||||
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|
||||
|| llvm::is_contained(strides, ShapedType::kDynamic))
|
||||
return failure();
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return ResolvedContiguousAddress{value, byteOffset};
|
||||
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedContiguousAddress {
|
||||
mlir::Value base;
|
||||
int64_t byteOffset = 0;
|
||||
};
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
@@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
template <typename T>
|
||||
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
|
||||
public:
|
||||
llvm::DenseMap<mlir::Value, T> map;
|
||||
|
||||
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
|
||||
: ForwardingListener(&listener) {}
|
||||
|
||||
void notifyOperationErased(mlir::Operation* op) override {
|
||||
for (mlir::Value result : op->getResults())
|
||||
map.erase(result);
|
||||
}
|
||||
|
||||
void notifyBlockErased(mlir::Block* block) override {
|
||||
for (mlir::BlockArgument arg : block->getArguments())
|
||||
map.erase(arg);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
|
||||
public:
|
||||
llvm::DenseMap<mlir::Value, T> map;
|
||||
|
||||
NotErasableValueMap(mlir::OpBuilder::Listener listener)
|
||||
: ForwardingListener(&listener) {}
|
||||
|
||||
void notifyOperationErased(mlir::Operation* op) override {
|
||||
for (mlir::Value result : op->getResults())
|
||||
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
|
||||
}
|
||||
|
||||
void notifyBlockErased(mlir::Block* block) override {
|
||||
for (mlir::BlockArgument arg : block->getArguments())
|
||||
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
|
||||
}
|
||||
};
|
||||
@@ -1,44 +1,37 @@
|
||||
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
|
||||
|
||||
add_onnx_mlir_library(OMPimCompilerOptions
|
||||
add_pim_library(OMPimCompilerOptions
|
||||
PimCompilerOptions.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PRIVATE
|
||||
${PIM_SRC_ROOT}
|
||||
${PIM_BIN_ROOT}
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||
${PIM_COMPILER_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
${OMLibs}
|
||||
OMCompilerOptions
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||
${PIM_ACCEL_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
add_onnx_mlir_library(OMPimCompilerUtils
|
||||
add_pim_library(OMPimCompilerUtils
|
||||
PimCompilerUtils.cpp
|
||||
PimCodeGen.cpp
|
||||
../Pass/EmitPimJsonPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PRIVATE
|
||||
${PIM_SRC_ROOT}
|
||||
${PIM_BIN_ROOT}
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||
${PIM_COMPILER_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
${OMLibs}
|
||||
OMCompilerUtils
|
||||
OMPimCompilerOptions
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
OMPimPasses
|
||||
OMONNXToSpatial
|
||||
OMSpatialToPim
|
||||
OMCompilerPasses
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_ONNX_MLIR_SRC_ROOT}
|
||||
${PIM_ONNX_MLIR_BIN_ROOT}
|
||||
${PIM_ACCEL_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
+142
-62
@@ -14,12 +14,11 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -86,48 +85,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
}
|
||||
|
||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
size_t offset = 0;
|
||||
while (true) {
|
||||
auto definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
break;
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
break;
|
||||
value = tiedOperand->get();
|
||||
}
|
||||
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
auto source = subviewDefiningOp.getSource();
|
||||
auto srcShape = source.getType().getShape();
|
||||
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
|
||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
||||
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
|
||||
size_t localOffset = subviewOffsets[i];
|
||||
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
|
||||
localOffset *= subviewSizes[j];
|
||||
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
|
||||
}
|
||||
value = source;
|
||||
}
|
||||
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
}
|
||||
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
}
|
||||
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
|
||||
auto iter = memEntriesMap.find(value);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (failed(resolvedAddress)) {
|
||||
errs() << "Failed to resolve contiguous address for value: ";
|
||||
value.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = value.getDefiningOp()) {
|
||||
@@ -135,10 +95,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Failed to resolve contiguous address");
|
||||
}
|
||||
|
||||
auto iter = memEntriesMap.find(resolvedAddress->base);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
resolvedAddress->base.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
|
||||
errs() << "Defining op:\n";
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + offset;
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
@@ -264,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
|
||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vaddOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vaddOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
|
||||
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
|
||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvaddOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvaddOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvadd";
|
||||
@@ -279,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = totalBytes;
|
||||
json["len"] = getValueSizeInBytes(vvaddOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vmaxOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vmaxOp.getB());
|
||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvsubOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvsubOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvsub";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvsubOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvmulOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvmulOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmulOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -295,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
|
||||
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
|
||||
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvdmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vavgOp.getA());
|
||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vavg";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vavgOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
@@ -308,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vreluOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vtanhOp.getA());
|
||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vtanh";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vtanhOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
|
||||
auto aAddr = memory.getValueAddress(vsigmOp.getA());
|
||||
setupRdRs1(outBufAddr, 0, aAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vsigm";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vsigmOp.getA());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
@@ -365,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
|
||||
vaddJson["rs1"] = 1;
|
||||
vaddJson["rs2"] = 2;
|
||||
vaddJson["offset"] = createEmptyOffset();
|
||||
vaddJson["len"] = 32 * outChannels;
|
||||
emitInstruction(std::move(vaddJson));
|
||||
}
|
||||
}
|
||||
@@ -506,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
|
||||
coreCodeGen.codeGenVAddOp(vaddOp);
|
||||
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
|
||||
coreCodeGen.codeGenVMaxOp(vmaxOp);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
coreCodeGen.codeGenVVAddOp(vvaddOp);
|
||||
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
||||
coreCodeGen.codeGenVVSubOp(vvsubOp);
|
||||
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
||||
coreCodeGen.codeGenVVMulOp(vvmulOp);
|
||||
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
||||
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
|
||||
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
||||
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
|
||||
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
||||
coreCodeGen.codeGenVAvgOp(vavgOp);
|
||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||
coreCodeGen.codeGenVReluOp(vreluOp);
|
||||
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
|
||||
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
||||
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||
else if (isa<pim::PimSumOp>(op)) {
|
||||
// TODO: Implement somehow?
|
||||
op.emitWarning("Operation is not yet supported in code generation");
|
||||
continue;
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
#include "Common/ValueMap.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -91,9 +90,15 @@ public:
|
||||
template <typename MVMTy>
|
||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
||||
|
||||
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
|
||||
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
|
||||
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
|
||||
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
|
||||
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
|
||||
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
|
||||
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
|
||||
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
|
||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||
};
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerUtils"
|
||||
|
||||
@@ -48,8 +47,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createPimHostVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim host verified"));
|
||||
pm.addPass(createPimMaterializeConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||
|
||||
@@ -2,23 +2,27 @@ set(LLVM_TARGET_DEFINITIONS ONNXToSpatial.td)
|
||||
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMONNXToSpatial
|
||||
Math/Gemm.cpp
|
||||
Math/Conv.cpp
|
||||
NN/Pooling.cpp
|
||||
NN/ReduceMean.cpp
|
||||
Tensor/ONNXConcatToTensorConcat.cpp
|
||||
Tensor/RemoveUnusedHelperOps.cpp
|
||||
add_pim_library(OMONNXToSpatial
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
Patterns/NN/Pool.cpp
|
||||
Patterns/NN/ReduceMean.cpp
|
||||
Patterns/Tensor/Concat.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
Utils/SpatialReducer.cpp
|
||||
Utils/WeightSubdivider.cpp
|
||||
Utils/AnnotateReplication.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
ONNXToSpatialCommon.cpp
|
||||
Common.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
DEPENDS
|
||||
ONNXToSpatialIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCompilerOptions
|
||||
OMONNXOps
|
||||
@@ -26,5 +30,5 @@ add_onnx_mlir_library(OMONNXToSpatial
|
||||
OMPimCommon
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
${PIM_GENERATED_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "ONNXToSpatialCommon.hpp"
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -1,427 +0,0 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess) {
|
||||
// Simple case: if we have only one input, just return it
|
||||
if (valuesToReduce.size() == 1)
|
||||
return valuesToReduce[0];
|
||||
|
||||
if (preprocess) {
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = preprocess(valToReduce);
|
||||
}
|
||||
}
|
||||
|
||||
// It is possible that `valuesToReduce` contains two entries for the same
|
||||
// computeOp. In this case, we need to apply the reduction within-computef
|
||||
|
||||
// Keep a map between a computeOp and the last Value for this reduction
|
||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
|
||||
// if (valToReduce.getDefiningOp()) {
|
||||
// // If the value is defined by an operation, we take the parent
|
||||
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
|
||||
// } else {
|
||||
// // Otherwise it is a block argument,
|
||||
// computeOp->getBlock()->getParentOp();
|
||||
// }
|
||||
|
||||
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
|
||||
|
||||
auto it = lastValueForCompute.find(computeOp);
|
||||
|
||||
if (it != lastValueForCompute.end()) {
|
||||
// If we have already seen this computeOp, apply the reduction
|
||||
// within-compute
|
||||
Value lastWithinComputeValue = it->second;
|
||||
|
||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
||||
else
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = reduce(lastWithinComputeValue, valToReduce);
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
|
||||
// Now, reconstruct from the map the valuesToReduce list
|
||||
valuesToReduce.clear();
|
||||
valuesToReduce.reserve(lastValueForCompute.size());
|
||||
for (auto& entry : lastValueForCompute)
|
||||
valuesToReduce.push_back(entry.second);
|
||||
|
||||
Location loc = valuesToReduce[0].getLoc();
|
||||
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
|
||||
|
||||
// Recursive algorithm to reduce the inputs to a single one:
|
||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
||||
// the valuesToReduce list which becomes half the size.
|
||||
// - Repeat until there is only one input left.
|
||||
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
|
||||
while (valuesToReduceRef.size() > 1) {
|
||||
SmallVector<Value> nextValuesToReduce;
|
||||
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
|
||||
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
|
||||
auto firstValue = valuesToReduceRef[i];
|
||||
auto secondValue = valuesToReduceRef[i + 1];
|
||||
|
||||
auto firstCompute = firstValue.getParentBlock()->getParentOp();
|
||||
auto secondCompute = secondValue.getParentBlock()->getParentOp();
|
||||
|
||||
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
|
||||
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
|
||||
|
||||
if (secondCompute->isBeforeInBlock(firstCompute)) {
|
||||
std::swap(firstValue, secondValue);
|
||||
std::swap(firstCompute, secondCompute);
|
||||
}
|
||||
|
||||
// 1. Add a channel before the first computeOp
|
||||
rewriter.setInsertionPoint(firstCompute);
|
||||
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||
|
||||
// 2. Add a sendOp after the first value
|
||||
rewriter.setInsertionPointAfterValue(firstValue);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
|
||||
|
||||
// 3. Add a receiveOp after the second value
|
||||
rewriter.setInsertionPointAfterValue(secondValue);
|
||||
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
|
||||
|
||||
// 4. Apply reduction between second value and received value
|
||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
||||
Value reduced = reduce(receivedValue, secondValue);
|
||||
|
||||
nextValuesToReduce.push_back(reduced);
|
||||
}
|
||||
|
||||
// If we have an odd number of inputs, we need to add the last one to the
|
||||
// newInputs list.
|
||||
if (valuesToReduceRef.size() % 2 == 1)
|
||||
nextValuesToReduce.push_back(valuesToReduceRef.back());
|
||||
|
||||
// Replace the inputOps list with the new one.
|
||||
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
||||
}
|
||||
|
||||
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
|
||||
|
||||
auto finalValue = valuesToReduceRef[0];
|
||||
|
||||
if (postprocess) {
|
||||
rewriter.setInsertionPointAfterValue(finalValue);
|
||||
finalValue = postprocess(finalValue);
|
||||
}
|
||||
|
||||
return finalValue;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
bool hasPostProcessPoolingWindow() {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
||||
loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
PoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
size_t input_h = getImageHeight(xShape);
|
||||
size_t input_w = getImageWidth(xShape);
|
||||
size_t output_h = getImageHeight(yShape);
|
||||
size_t output_w = getImageWidth(yShape);
|
||||
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
|
||||
|
||||
// 1: Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
||||
// Suppose that the input tensor is produced by concatenating the results of
|
||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt =
|
||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
||||
|
||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
||||
// can do better.
|
||||
// If some input tiles come from the func.func operands, load
|
||||
// them into a computeOp and yield them
|
||||
for (size_t t = 0; t < channelTileCount; t++) {
|
||||
for (size_t x = 0; x < input_w; x++) {
|
||||
for (size_t y = 0; y < input_h; y++) {
|
||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Location tileLoc = extractSliceOp.getLoc();
|
||||
|
||||
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
||||
tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
|
||||
Block* tempComputeOpBlock = new Block();
|
||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||
|
||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2: Tile the output tensor
|
||||
// Output tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: outputTiles[channelTile][x][y]
|
||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
||||
|
||||
// List of values to pool for each output pixel
|
||||
SmallVector<Value> valuesToPool;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
|
||||
// Each output pixel tile is computed by pooling a window of input
|
||||
// pixel tiles
|
||||
valuesToPool.clear();
|
||||
size_t tilesSkippedByPadding = 0;
|
||||
|
||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
||||
|
||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
||||
tilesSkippedByPadding++;
|
||||
continue;
|
||||
}
|
||||
|
||||
Value inputTile = inputTiles[outTile][inX][inY];
|
||||
|
||||
Value valueToPool;
|
||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
||||
|
||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
||||
|
||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
||||
}
|
||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
||||
if (failed(sendOpOpt)) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"ChannelReceiveOp does not have a matching "
|
||||
"ChannelSendOp.");
|
||||
}
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
valueToPool = sendOp.getData();
|
||||
}
|
||||
else {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Input tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp nor a receiveOp");
|
||||
}
|
||||
|
||||
valuesToPool.push_back(valueToPool);
|
||||
}
|
||||
}
|
||||
|
||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
||||
// assert(computeOpsForPooling.size() != 1 &&
|
||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
||||
// this " "should have been simplified earlier???");
|
||||
|
||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
||||
postProcessFn = [&](const Value prevFinalRes) {
|
||||
return postProcessPoolingWindow(
|
||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
||||
};
|
||||
}
|
||||
|
||||
Value reducedWithinCompute = applyReducePatternNew(
|
||||
valuesToPool,
|
||||
rewriter,
|
||||
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
||||
nullptr,
|
||||
postProcessFn);
|
||||
|
||||
// Send this value through a channel, and receive it in the
|
||||
// `func.func`. During lowering, we will need to "move it" into the
|
||||
// users computeOps
|
||||
auto computeOpOfReduced =
|
||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
||||
|
||||
// Create a new channel before the computeOp
|
||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||
auto reduceChannel =
|
||||
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
|
||||
// Send value through the channel
|
||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
||||
|
||||
// Receive after the computeOp
|
||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||
auto receivedValue =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
|
||||
outputTiles[outTile][outX][outY] = receivedValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
||||
// them!
|
||||
|
||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
auto outputTile = outputTiles[outTile][outX][outY];
|
||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
||||
if (!outputTileProducer) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Output tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp.");
|
||||
}
|
||||
|
||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
||||
|
||||
rewriter.replaceOp(poolOp, outputImage);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
|
||||
ctx);
|
||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
|
||||
|
||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||
|
||||
def IsRank2Result: Constraint<
|
||||
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||
"Result is rank 2">;
|
||||
|
||||
def matMulAddToGemmPattern : Pat<
|
||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||
(ONNXGemmOp $A, $B, $C,
|
||||
@@ -22,19 +26,21 @@ def matMulAddToGemmPattern : Pat<
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
),
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
def matMulToGemmPattern : Pat<
|
||||
(ONNXMatMulOp:$matmulres $A, $B),
|
||||
(
|
||||
ONNXGemmOp $A, $B,
|
||||
ONNXGemmOp $A, $B,
|
||||
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
),
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
@@ -12,7 +13,7 @@
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
@@ -56,6 +57,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||
@@ -74,7 +76,9 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
|
||||
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
|
||||
});
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
target.addIllegalOp<ONNXConvOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
@@ -83,13 +87,15 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXConcatOp>();
|
||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXReshapeOp>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRNPattern>(ctx);
|
||||
|
||||
populateConvOpPatterns(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populatePoolTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
populateReshapeConversionPattern(patterns, ctx);
|
||||
|
||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||
populateReduceMeanConversionPattern(patterns, ctx);
|
||||
@@ -113,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
|
||||
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
|
||||
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
|
||||
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
|
||||
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
if (failed(cleanupPM.run(moduleOp)))
|
||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
|
||||
+4
-2
@@ -7,13 +7,15 @@ namespace onnx_mlir {
|
||||
|
||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
+3
-11
@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
|
||||
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
|
||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
Value gemmC;
|
||||
if (hasB) {
|
||||
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
|
||||
gemmC = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
biasType,
|
||||
b,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
}
|
||||
if (hasB)
|
||||
gemmC = b;
|
||||
else
|
||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
+69
-11
@@ -11,7 +11,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -23,6 +23,38 @@ namespace {
|
||||
|
||||
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
|
||||
static FailureOr<Value> materializeScaledConstantTensor(Value value,
|
||||
float factor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (factor == 1.0f)
|
||||
return value;
|
||||
|
||||
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
||||
if (!constantOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
SmallVector<APFloat> scaledValues;
|
||||
scaledValues.reserve(denseAttr.getNumElements());
|
||||
APFloat scale(factor);
|
||||
bool hadFailure = false;
|
||||
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
|
||||
APFloat scaledValue(originalValue);
|
||||
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
|
||||
hadFailure = true;
|
||||
scaledValues.push_back(std::move(scaledValue));
|
||||
}
|
||||
if (hadFailure)
|
||||
return failure();
|
||||
|
||||
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
|
||||
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||
if (failed(scaledB))
|
||||
return failure();
|
||||
b = *scaledB;
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool cHasNumOutRows = false;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
if (failed(scaledC))
|
||||
return failure();
|
||||
c = *scaledC;
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
|
||||
cType = expandedType;
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
aSlice,
|
||||
b,
|
||||
cSlice,
|
||||
gemmOp.getAlphaAttr(),
|
||||
gemmOp.getBetaAttr(),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
gemmOp.getTransAAttr(),
|
||||
gemmOp.getTransBAttr());
|
||||
gemvOps.push_back(gemvOp.getY());
|
||||
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
|
||||
cType = expandedType;
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
}
|
||||
|
||||
@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
if (alpha != 1.0f) {
|
||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
|
||||
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
|
||||
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
|
||||
if (failed(scaledB))
|
||||
return failure();
|
||||
b = *scaledB;
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
alpha = 1.0f;
|
||||
}
|
||||
if (hasC && beta != 1.0f) {
|
||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
|
||||
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
|
||||
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
|
||||
if (failed(scaledC))
|
||||
return failure();
|
||||
c = *scaledC;
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
beta = 1.0f;
|
||||
}
|
||||
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
@@ -0,0 +1,108 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
||||
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
||||
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
||||
return failure();
|
||||
|
||||
const int64_t batch = rhsType.getDimSize(0);
|
||||
const int64_t k = rhsType.getDimSize(1);
|
||||
const int64_t n = rhsType.getDimSize(2);
|
||||
const int64_t m = lhsType.getDimSize(0);
|
||||
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
||||
|| outType.getDimSize(2) != n)
|
||||
return failure();
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
||||
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
||||
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
||||
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
||||
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
||||
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
||||
|
||||
Value lhsTransposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
SmallVector<Value> gemmRows;
|
||||
gemmRows.reserve(batch * n);
|
||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value rhsSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
|
||||
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmRowType,
|
||||
rhsRow,
|
||||
lhsTransposed,
|
||||
none,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
gemmRows.push_back(gemmOp.getY());
|
||||
}
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (Value gemmRow : gemmRows)
|
||||
concatBlock->addArgument(gemmRow.getType(), loc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
|
||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||
|
||||
rewriter.setInsertionPointAfter(concatComputeOp);
|
||||
Value gemmOut = concatComputeOp.getResult(0);
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(
|
||||
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
|
||||
Value result = ONNXTransposeOp::create(
|
||||
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<MatMulRank3ToGemm>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,265 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||
}
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(tileType.getRank());
|
||||
for (int64_t dimSize : tileType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename ReduceOp>
|
||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||
|
||||
Value reduced = windowValues.front();
|
||||
for (Value value : windowValues.drop_front())
|
||||
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static Value
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||
if (divisor == 1)
|
||||
return reducedWindow;
|
||||
|
||||
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
||||
double scale = 1.0 / static_cast<double>(divisor);
|
||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
struct PoolToSpatialCompute;
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
using OpConversionPattern<PoolOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location loc = poolOp.getLoc();
|
||||
Value x = adaptor.getX();
|
||||
|
||||
auto xType = dyn_cast<RankedTensorType>(x.getType());
|
||||
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
|
||||
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
|
||||
if (xType.getRank() != 4 || outType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
|
||||
|
||||
ArrayAttr kernelAttr = poolOp.getKernelShape();
|
||||
if (!kernelAttr || kernelAttr.size() != 2)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t channels = xType.getDimSize(1);
|
||||
const int64_t inputHeight = xType.getDimSize(2);
|
||||
const int64_t inputWidth = xType.getDimSize(3);
|
||||
const int64_t outputHeight = outType.getDimSize(2);
|
||||
const int64_t outputWidth = outType.getDimSize(3);
|
||||
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||
|
||||
int64_t padTop = 0;
|
||||
int64_t padLeft = 0;
|
||||
int64_t padBottom = 0;
|
||||
int64_t padRight = 0;
|
||||
|
||||
if (auto padsAttr = poolOp.getPads()) {
|
||||
if (padsAttr->size() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||
padTop = getI64(*padsAttr, 0);
|
||||
padLeft = getI64(*padsAttr, 1);
|
||||
padBottom = getI64(*padsAttr, 2);
|
||||
padRight = getI64(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
StringRef autoPad = poolOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
|
||||
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padTop = totalPadH / 2;
|
||||
padBottom = totalPadH - padTop;
|
||||
padLeft = totalPadW / 2;
|
||||
padRight = totalPadW - padLeft;
|
||||
}
|
||||
else {
|
||||
padBottom = totalPadH / 2;
|
||||
padTop = totalPadH - padBottom;
|
||||
padRight = totalPadW / 2;
|
||||
padLeft = totalPadW - padRight;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
|
||||
}
|
||||
}
|
||||
|
||||
(void) padBottom;
|
||||
(void) padRight;
|
||||
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
computeBlock->addArgument(xType, loc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
Value input = computeBlock->getArgument(0);
|
||||
SmallVector<Value> batchResults;
|
||||
batchResults.reserve(batchSize);
|
||||
|
||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||
SmallVector<Value> rows;
|
||||
rows.reserve(outputHeight);
|
||||
|
||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||
SmallVector<Value> rowPixels;
|
||||
rowPixels.reserve(outputWidth);
|
||||
|
||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||
SmallVector<Value> outputChannelTiles;
|
||||
outputChannelTiles.reserve(channelTileCount);
|
||||
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> windowValues;
|
||||
windowValues.reserve(kernelHeight * kernelWidth);
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||
if (inH < 0 || inH >= inputHeight)
|
||||
continue;
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||
if (inW < 0 || inW >= inputWidth)
|
||||
continue;
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
rewriter.getIndexAttr(inH),
|
||||
rewriter.getIndexAttr(inW)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
windowValues.push_back(windowValue);
|
||||
}
|
||||
}
|
||||
|
||||
if (windowValues.empty())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||
|
||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
const int64_t divisor =
|
||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||
}
|
||||
|
||||
outputChannelTiles.push_back(reducedWindow);
|
||||
}
|
||||
|
||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||
}
|
||||
|
||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||
}
|
||||
|
||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||
}
|
||||
|
||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||
|
||||
rewriter.replaceOp(poolOp, computeOp.getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
|
||||
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
|
||||
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
||||
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
|
||||
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
+4
-4
@@ -1,15 +1,15 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
||||
ONNXConcatToTensorConcat(MLIRContext* ctx)
|
||||
struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
Concat(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
||||
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
||||
};
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXConcatToTensorConcat>(ctx);
|
||||
patterns.insert<Concat>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,121 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
SmallVector<ReassociationIndices>& reassociation) {
|
||||
reassociation.clear();
|
||||
|
||||
size_t sourceIdx = 0;
|
||||
size_t resultIdx = 0;
|
||||
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||
int64_t resultProduct = resultShape[resultIdx];
|
||||
|
||||
ReassociationIndices group;
|
||||
group.push_back(sourceIdx);
|
||||
while (sourceProduct != resultProduct) {
|
||||
if (sourceProduct > resultProduct)
|
||||
return false;
|
||||
sourceIdx++;
|
||||
if (sourceIdx >= sourceShape.size())
|
||||
return false;
|
||||
group.push_back(sourceIdx);
|
||||
sourceProduct *= sourceShape[sourceIdx];
|
||||
}
|
||||
|
||||
reassociation.push_back(group);
|
||||
sourceIdx++;
|
||||
resultIdx++;
|
||||
}
|
||||
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
SmallVector<ReassociationIndices>& reassociation) {
|
||||
reassociation.clear();
|
||||
|
||||
size_t sourceIdx = 0;
|
||||
size_t resultIdx = 0;
|
||||
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||
int64_t resultProduct = resultShape[resultIdx];
|
||||
|
||||
ReassociationIndices group;
|
||||
group.push_back(resultIdx);
|
||||
while (resultProduct != sourceProduct) {
|
||||
if (resultProduct > sourceProduct)
|
||||
return false;
|
||||
resultIdx++;
|
||||
if (resultIdx >= resultShape.size())
|
||||
return false;
|
||||
group.push_back(resultIdx);
|
||||
resultProduct *= resultShape[resultIdx];
|
||||
}
|
||||
|
||||
reassociation.push_back(group);
|
||||
sourceIdx++;
|
||||
resultIdx++;
|
||||
}
|
||||
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
|
||||
ONNXReshapeOpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
if (sourceType == resultType) {
|
||||
rewriter.replaceOp(reshapeOp, adaptor.getData());
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
if (sourceType.getRank() > resultType.getRank()
|
||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (sourceType.getRank() < resultType.getRank()
|
||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<Reshape>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,35 +0,0 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename OpTy, typename OpAdaptorTy>
|
||||
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
|
||||
RemoveUnusedHelperOps(MLIRContext* ctx)
|
||||
: OpRewritePattern<OpTy>(ctx) {}
|
||||
|
||||
void initialize() { this->setHasBoundedRewriteRecursion(); }
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
|
||||
if (op.getResult().use_empty()) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
|
||||
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
|
||||
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,7 +1,7 @@
|
||||
#include <queue>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
add_onnx_mlir_rewriter(SpatialToGraphviz)
|
||||
|
||||
add_onnx_mlir_library(OMSpatialToGraphviz
|
||||
add_pim_library(OMSpatialToGraphviz
|
||||
SpatialToGraphviz.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
OMONNXOps
|
||||
SpatialOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
${PIM_GENERATED_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
@@ -2,19 +2,22 @@ set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
|
||||
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(SpatialToPimIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMSpatialToPim
|
||||
add_pim_library(OMSpatialToPim
|
||||
SpatialToPimPass.cpp
|
||||
SpatialToPimCommon.cpp
|
||||
Common.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
DEPENDS
|
||||
SpatialToPimIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
SpatialOps
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
${PIM_GENERATED_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include "SpatialToPimCommon.hpp"
|
||||
#include "Common.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVAddOp : Pat<
|
||||
def spatToPimVVAddOp : Pat<
|
||||
(SpatVAddOp:$srcOpRes $a, $b),
|
||||
(PimVAddOp $a, $b,
|
||||
(PimVVAddOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMulOp : Pat<
|
||||
(SpatVMulOp:$srcOpRes $a, $b),
|
||||
(PimVVMulOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMaxOp : Pat<
|
||||
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||
(PimVVMaxOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
@@ -18,9 +19,9 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
@@ -56,10 +57,8 @@ private:
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||
void addReceiveOps(Value channelSourceOp,
|
||||
spatial::SpatChannelNewOp& channel,
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void
|
||||
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
|
||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||
unsigned int argIndex,
|
||||
Value channelSourceOp,
|
||||
@@ -79,8 +78,31 @@ private:
|
||||
} // namespace
|
||||
|
||||
static bool isChannelUseChainOp(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
|
||||
op);
|
||||
return isa<tensor::ExtractSliceOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CastOp,
|
||||
tosa::ReshapeOp,
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
}
|
||||
|
||||
static size_t countComputeLeafUsers(Value value) {
|
||||
@@ -175,35 +197,67 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
||||
|
||||
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
||||
// If this result has no uses, then just skip it
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
|
||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||
|
||||
/*
|
||||
* Here we assume that ReturnOp are only reachable by the following patterns:
|
||||
*
|
||||
* 1)
|
||||
* %0 = spat.compute([...])
|
||||
* [%0 has one user, which is a ConcatOp]
|
||||
* %1 = tensor.concat(%0)
|
||||
* [%1 has one user, which is a ReturnOp]
|
||||
* return %1
|
||||
*
|
||||
* 2)
|
||||
* %0 = spat.compute([...])
|
||||
* [%0 has one user, which is a ReturnOp]
|
||||
* return %0
|
||||
*
|
||||
* If the IR is like 2), then we can store the tensor to the output global memory location
|
||||
*/
|
||||
auto resultUses = result.getUses();
|
||||
auto numResultUses = rangeLength(resultUses);
|
||||
if (numResultUses == 1) {
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isChannelUseChainOp(resultUser)) {
|
||||
SmallVector<Operation*> returnChain;
|
||||
Value chainedValue = result;
|
||||
Operation* chainUser = resultUser;
|
||||
|
||||
while (isChannelUseChainOp(chainUser)) {
|
||||
returnChain.push_back(chainUser);
|
||||
auto chainUses = chainUser->getResult(0).getUses();
|
||||
if (rangeLength(chainUses) != 1)
|
||||
break;
|
||||
chainedValue = chainUser->getResult(0);
|
||||
chainUser = chainUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(chainUser)) {
|
||||
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
IRMapping mapping;
|
||||
mapping.map(result, yieldValue);
|
||||
|
||||
Value storedValue = yieldValue;
|
||||
for (Operation* op : returnChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
storedValue = clonedOp->getResult(0);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
markOpToRemove(op);
|
||||
}
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
storedValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t offset = 0;
|
||||
@@ -475,7 +529,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
receivedValue =
|
||||
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
else
|
||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
receivedValue =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||
|
||||
Value replacementValue = receivedValue;
|
||||
if (consumerValue != channelSourceOp) {
|
||||
@@ -493,6 +548,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
IRMapping mapping;
|
||||
mapping.map(channelSourceOp, receivedValue);
|
||||
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -502,7 +558,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||
}
|
||||
|
||||
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
|
||||
assert(replacementValue.getType() == blockArg.getType()
|
||||
&& "Replayed channel use chain must match block argument type");
|
||||
blockArg.replaceAllUsesWith(replacementValue);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
#pragma once
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace spatial {
|
||||
|
||||
// TODO: Add here eventual patterns
|
||||
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,10 +3,12 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
|
||||
add_onnx_mlir_library(PimOps
|
||||
add_pim_library(PimOps
|
||||
PimOps.hpp
|
||||
PimOps.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
DEPENDS
|
||||
OMPimIncGen
|
||||
|
||||
|
||||
+101
-9
@@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
||||
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise addition: c = a + b
|
||||
}];
|
||||
@@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise subtraction: c = a - b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise multiplication: c = a * b
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
|
||||
let description = [{
|
||||
Element-wise max: c = max(a, b)
|
||||
}];
|
||||
@@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutBufMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Dot product: c = dot(a, b)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $b,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
@@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
|
||||
);
|
||||
}
|
||||
|
||||
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
|
||||
Average all elements into a single one
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $dividend,
|
||||
PimTensor: $divisor,
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
@@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
|
||||
);
|
||||
}
|
||||
|
||||
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise exp: c = exp(a)
|
||||
Element-wise tanh activation
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor: $a,
|
||||
PimTensor: $outBuf
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor: $outRes
|
||||
);
|
||||
}
|
||||
|
||||
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
|
||||
let description = [{
|
||||
Element-wise sigmoid activation
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@@ -388,4 +480,4 @@ def PimHaltOp: PimOp<"halt", [Terminator]> {
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // PIM_DIALECT_H
|
||||
#endif // PIM_DIALECT_H
|
||||
|
||||
@@ -30,12 +30,13 @@ void PimDialect::initialize() {
|
||||
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
|
||||
}
|
||||
|
||||
POPULATE_DEPENDENCIES(PimVMaxOp)
|
||||
POPULATE_DEPENDENCIES(PimVVDMulOp)
|
||||
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
|
||||
POPULATE_DEPENDENCIES(PimSumOp)
|
||||
POPULATE_DEPENDENCIES(PimVSDivOp)
|
||||
POPULATE_DEPENDENCIES(PimVAvgOp)
|
||||
POPULATE_DEPENDENCIES(PimVReluOp)
|
||||
POPULATE_DEPENDENCIES(PimVExpOp)
|
||||
POPULATE_DEPENDENCIES(PimVTanhOp)
|
||||
POPULATE_DEPENDENCIES(PimVSigmOp)
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,13 +2,15 @@ set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
||||
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMPimBufferization
|
||||
add_pim_library(OMPimBufferization
|
||||
PimBufferizationPass.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
Common.hpp
|
||||
Common.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
DEPENDS
|
||||
PimBufferizationIncGen
|
||||
|
||||
@@ -17,5 +19,5 @@ add_onnx_mlir_library(OMPimBufferization
|
||||
PimOps
|
||||
|
||||
ACCEL_INCLUDE_DIRS PRIVATE
|
||||
${PIM_INCLUDE_PATH}
|
||||
${PIM_GENERATED_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -12,6 +13,26 @@ using namespace bufferization;
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousType,
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
}
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
}
|
||||
};
|
||||
|
||||
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
@@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto vaddOp = cast<PimVAddOp>(op);
|
||||
auto binaryOp = cast<OpTy>(op);
|
||||
|
||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
||||
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
|
||||
if (failed(aOpt))
|
||||
return failure();
|
||||
|
||||
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
|
||||
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
|
||||
if (failed(bOpt))
|
||||
return failure();
|
||||
|
||||
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
|
||||
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
|
||||
if (failed(outBufOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
|
||||
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
|
||||
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
|
||||
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
|
||||
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
|
||||
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
|
||||
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
add_onnx_mlir_dialect(Spatial spat)
|
||||
add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
|
||||
add_onnx_mlir_library(SpatialOps
|
||||
add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/SpatialBufferizableOpInterface.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
DEPENDS
|
||||
OMONNXIncGen
|
||||
OMSpatialIncGen
|
||||
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
OMMlirDialects
|
||||
OMONNXOps
|
||||
OMPimCompilerOptions
|
||||
PimOps
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
|
||||
@@ -37,8 +37,76 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
||||
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||
}
|
||||
|
||||
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
return pim::PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousBuffer.getType(),
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getDstOut();
|
||||
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
|
||||
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto channelUsers = channelNewOp->getUsers();
|
||||
auto usersIterator = channelUsers.begin();
|
||||
auto firstUser = *usersIterator;
|
||||
++usersIterator;
|
||||
if (usersIterator == channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
|
||||
channelNewOp->dump();
|
||||
op->dump();
|
||||
channelNewOp->getParentOp()->dump();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto secondUser = *usersIterator;
|
||||
++usersIterator;
|
||||
if (usersIterator != channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Operation* otherUser = nullptr;
|
||||
if (firstUser == op)
|
||||
otherUser = secondUser;
|
||||
else if (secondUser == op)
|
||||
otherUser = firstUser;
|
||||
else {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
return otherUser;
|
||||
}
|
||||
|
||||
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||
|
||||
// This function requires the existence of ChannelNewOp and the other
|
||||
@@ -49,7 +117,7 @@ llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsR
|
||||
if (precomputedOtherCoreId)
|
||||
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
|
||||
|
||||
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
|
||||
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
|
||||
if (failed(notOpUserOpt))
|
||||
return failure();
|
||||
Operation* notOpUser = *notOpUserOpt;
|
||||
@@ -119,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
auto memref = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(memref))
|
||||
return failure();
|
||||
memrefOperands.push_back(*memref);
|
||||
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||
}
|
||||
|
||||
// TODO: Support addiction with more than 2 operands
|
||||
@@ -412,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
|
||||
};
|
||||
|
||||
struct VAddOpInterfaceFromTemplate
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
||||
|
||||
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||
|
||||
@@ -420,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
|
||||
|
||||
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
|
||||
|
||||
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
|
||||
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
|
||||
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
// Create a new bufferizable op interface for the apply filters operation.
|
||||
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
|
||||
@@ -509,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
|
||||
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
|
||||
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
|
||||
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
|
||||
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
|
||||
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
|
||||
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
|
||||
@@ -521,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
|
||||
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||
|
||||
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
|
||||
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
||||
|
||||
struct ONNXSigmoidInterface
|
||||
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
|
||||
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
|
||||
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
|
||||
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
|
||||
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
add_pim_library(OMPimPasses
|
||||
CountInstructionPass.cpp
|
||||
MessagePass.cpp
|
||||
PimConstantFolding/Common.cpp
|
||||
PimConstantFolding/Patterns/Constant.cpp
|
||||
PimConstantFolding/PimConstantFoldingPass.cpp
|
||||
PimConstantFolding/Patterns/Subview.cpp
|
||||
PimMaterializeConstantsPass.cpp
|
||||
PimVerificationPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
OMPimCommon
|
||||
)
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -17,7 +18,10 @@ struct MessagePass : PassWrapper<MessagePass, OperationPass<ModuleOp>> {
|
||||
: message(message) {}
|
||||
MessagePass(const MessagePass& pass) {}
|
||||
|
||||
void runOnOperation() final { showCompilePhase(message); }
|
||||
void runOnOperation() final {
|
||||
llvm::outs() << message << "\n";
|
||||
llvm::outs().flush();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string message;
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
#include "Common.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment) {
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||
|
||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(moduleBuilder,
|
||||
loc,
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
denseAttr,
|
||||
/*constant=*/true,
|
||||
alignment);
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
return denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<int64_t> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||
mlir::Location loc,
|
||||
mlir::MemRefType globalType,
|
||||
mlir::DenseElementsAttr denseAttr,
|
||||
llvm::StringRef nameStem,
|
||||
mlir::IntegerAttr alignment = {});
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
llvm::ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+105
-236
@@ -1,19 +1,11 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -21,55 +13,14 @@
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment = {}) {
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||
|
||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(moduleBuilder,
|
||||
loc,
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
denseAttr,
|
||||
/*constant=*/true,
|
||||
alignment);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
return denseAttr;
|
||||
}
|
||||
struct ConstantSubviewCopy {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
Operation* copyOp = nullptr;
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
@@ -126,13 +77,6 @@ static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr den
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
struct ConstantSubviewCopy {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
Operation* copyOp = nullptr;
|
||||
};
|
||||
|
||||
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||
if (!mapOp.getInputs().empty())
|
||||
return failure();
|
||||
@@ -176,151 +120,13 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
rewriter.setInsertionPoint(coreOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
|
||||
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth == 0)
|
||||
return failure();
|
||||
size_t totalBytes = initType.getNumElements() * elementByteWidth;
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
initType,
|
||||
mapOp.getInit(),
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
Value source;
|
||||
SmallVector<int64_t> sourceShape;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
static int64_t
|
||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
}
|
||||
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
|
||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (copyOp.getSize() != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = copyOp.getSrcOffset()
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = copyOp.getDstOffset()
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : copyOp.getDst(),
|
||||
splitSrc ? srcSubview->source : copyOp.getSrc(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
@@ -473,17 +279,15 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
MemRefType globalType = resultType;
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp,
|
||||
transposeOp.getLoc(),
|
||||
globalType,
|
||||
resultType,
|
||||
*transposedAttr,
|
||||
sourceGlobal.getName().str() + "__folded_transpose",
|
||||
sourceGlobal.getAlignmentAttr());
|
||||
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
@@ -578,41 +382,106 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
owningPatterns
|
||||
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
|
||||
context);
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
||||
if (!allocOp)
|
||||
return failure();
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
void runOnOperation() override {
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
||||
resultValues[i] = sourceValues[srcLinear];
|
||||
}
|
||||
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
else {
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim2_folded");
|
||||
}
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
continue;
|
||||
if (isa<memref::DeallocOp>(user))
|
||||
continue;
|
||||
if (isa<pim::PimCoreOp>(user))
|
||||
continue;
|
||||
if (isa<memref::SubViewOp>(user)) {
|
||||
allLiveUsersAreCores = false;
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(allocOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
||||
rewriter.eraseOp(copyOp);
|
||||
if (allocOp.use_empty())
|
||||
rewriter.eraseOp(allocOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantMemCpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,223 @@
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstOffset,
|
||||
int64_t srcOffset,
|
||||
int64_t size,
|
||||
PatternRewriter& rewriter,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (size != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = srcOffset
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = dstOffset
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : dst,
|
||||
splitSrc ? srcSubview->source : src,
|
||||
dstByteOffset,
|
||||
srcByteOffset,
|
||||
sliceBytes);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto status =
|
||||
rewriteSubviewCopyLikeOp(copyOp,
|
||||
copyOp.getDst(),
|
||||
copyOp.getSrc(),
|
||||
copyOp.getDstOffset(),
|
||||
copyOp.getSrcOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](MemRefType resultType,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstByteOffset,
|
||||
int64_t srcByteOffset,
|
||||
int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(
|
||||
rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto status =
|
||||
rewriteSubviewCopyLikeOp(copyOp,
|
||||
copyOp.getDeviceDst(),
|
||||
copyOp.getHostSrc(),
|
||||
copyOp.getDeviceDstOffset(),
|
||||
copyOp.getHostSrcOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](MemRefType resultType,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstByteOffset,
|
||||
int64_t srcByteOffset,
|
||||
int64_t sliceBytes) {
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||
if (subviewOp.use_empty())
|
||||
return failure();
|
||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||
return failure();
|
||||
|
||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||
if (failed(subviewInfo))
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType =
|
||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern, FoldConstantCoreSubviewPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,53 @@
|
||||
#include "Patterns.hpp"
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
|
||||
populateConstantFoldingConstantPatterns(owningPatterns);
|
||||
populateConstantFoldingSubviewPatterns(owningPatterns);
|
||||
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim2_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,135 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
static int64_t getValueSizeInBytes(Value value) {
|
||||
auto type = dyn_cast<ShapedType>(value.getType());
|
||||
if (!type || !type.hasStaticShape())
|
||||
return -1;
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct PimMaterializeConstantsPass
|
||||
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
OpBuilder rewriter(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op.getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op.emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end()) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(&op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
|
||||
|
||||
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(
|
||||
static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(
|
||||
static_cast<int32_t>(totalBytes)));
|
||||
|
||||
cachedByType[originalType] = hostToDevCopy.getResult();
|
||||
operand.set(hostToDevCopy.getResult());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim3_materialized");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
|
||||
return std::make_unique<PimMaterializeConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@@ -26,49 +27,32 @@ static bool isAddressOnlyHostOp(Operation* op) {
|
||||
spatial::SpatChannelNewOp>(op);
|
||||
}
|
||||
|
||||
static bool isHostAddressableValue(Value value) {
|
||||
while (true) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (failed(resolvedAddress))
|
||||
return false;
|
||||
}
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-host-pass"; }
|
||||
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that no runtime host-side code remains in bufferized PIM IR";
|
||||
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||
}
|
||||
|
||||
PimHostVerificationPass() {}
|
||||
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
|
||||
PimVerificationPass() {}
|
||||
PimVerificationPass(const PimVerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -80,7 +64,7 @@ struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationP
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)))
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
@@ -139,14 +123,49 @@ private:
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
if (!isHostAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (Operation& op : coreOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlySource(op, subviewOp.getSource());
|
||||
@@ -160,16 +179,16 @@ private:
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
if (isHostAddressableValue(source))
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that still requires host-side execution");
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
|
||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createBufferizePimPass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimHostVerificationPass);
|
||||
registerPass(createPimMaterializeConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
# Validation Operations
|
||||
|
||||
ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline.
|
||||
|
||||
Generated tests can be regenerated with:
|
||||
```
|
||||
python3 validation/operations/gen_tests.py
|
||||
```
|
||||
|
||||
## Conv
|
||||
|
||||
| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes |
|
||||
|------|-----------|-------|--------|--------|--------|---------|------|-------|
|
||||
| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted |
|
||||
| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias |
|
||||
| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input |
|
||||
| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel |
|
||||
| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution |
|
||||
| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels |
|
||||
| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing |
|
||||
| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved |
|
||||
| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads |
|
||||
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
|
||||
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
||||
|
||||
## Pool
|
||||
|
||||
| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes |
|
||||
|------|-----------|-------|--------|--------|--------|---------|-------|
|
||||
| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling |
|
||||
| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool |
|
||||
| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path |
|
||||
| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling |
|
||||
| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` |
|
||||
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
|
||||
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
|
||||
|
||||
## Gemm
|
||||
|
||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||
|------|-----------|-----------|------------|--------|--------|-------|------|------|-------|
|
||||
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
|
||||
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
||||
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
|
||||
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
|
||||
| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias |
|
||||
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
|
||||
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
|
||||
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
|
||||
|
||||
## Gemv
|
||||
|
||||
| Test | Directory | Input | W (weight) | Output | Bias | Notes |
|
||||
|------|-----------|-------|------------|--------|------|-------|
|
||||
| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul |
|
||||
| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant |
|
||||
| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape |
|
||||
| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern |
|
||||
| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast |
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import helper, TensorProto, numpy_helper
|
||||
from pathlib import Path
|
||||
|
||||
OPERATIONS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def save_model(model, directory, filename):
|
||||
"""Save an ONNX model, creating the directory if needed."""
|
||||
d = OPERATIONS_DIR / directory
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
path = d / filename
|
||||
onnx.checker.check_model(model)
|
||||
onnx.save(model, str(path))
|
||||
print(f" {path.relative_to(OPERATIONS_DIR)}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GEMM tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def gemm_non_square():
|
||||
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
||||
B, K, N = 4, 128, 64
|
||||
W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/non_square", "gemm_non_square.onnx")
|
||||
|
||||
|
||||
def gemm_with_bias():
|
||||
"""GEMM with bias: Y = A @ W + C."""
|
||||
B, K, N = 4, 128, 128
|
||||
rng = np.random.default_rng(43)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/with_bias", "gemm_with_bias.onnx")
|
||||
|
||||
|
||||
def gemm_transB():
|
||||
"""GEMM with transB=1: Y = A @ W^T."""
|
||||
B, K, N = 4, 128, 64
|
||||
rng = np.random.default_rng(44)
|
||||
# W stored as [N, K], transposed during computation
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1)
|
||||
graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/transB", "gemm_transB.onnx")
|
||||
|
||||
|
||||
def gemm_alpha_beta():
|
||||
"""GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C."""
|
||||
B, K, N = 4, 64, 64
|
||||
rng = np.random.default_rng(45)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25)
|
||||
graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx")
|
||||
|
||||
|
||||
def gemm_small():
|
||||
"""Small GEMM: [2, 8] @ [8, 4]."""
|
||||
B, K, N = 2, 8, 4
|
||||
rng = np.random.default_rng(46)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/small", "gemm_small.onnx")
|
||||
|
||||
|
||||
def gemm_large():
|
||||
"""Larger GEMM: [8, 256] @ [256, 128]."""
|
||||
B, K, N = 8, 256, 128
|
||||
rng = np.random.default_rng(47)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/large", "gemm_large.onnx")
|
||||
|
||||
|
||||
def gemm_transB_with_bias():
|
||||
"""GEMM with transB and bias: Y = A @ W^T + C."""
|
||||
B, K, N = 4, 128, 64
|
||||
rng = np.random.default_rng(48)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
|
||||
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1)
|
||||
graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Conv tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def conv_3x3_kernel():
|
||||
"""Conv with 3x3 kernel, no padding."""
|
||||
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3]
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx")
|
||||
|
||||
|
||||
def conv_stride2():
|
||||
"""Conv with 3x3 kernel and stride 2."""
|
||||
# Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2]
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/stride_2", "conv_stride_2.onnx")
|
||||
|
||||
|
||||
def conv_multi_channel():
|
||||
"""Conv with multiple input and output channels."""
|
||||
# Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3]
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/multi_channel", "conv_multi_channel.onnx")
|
||||
|
||||
|
||||
def conv_1x1():
|
||||
"""1x1 pointwise convolution (channel mixing)."""
|
||||
# Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4]
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(53).uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "conv_1x1", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx")
|
||||
|
||||
|
||||
def conv_same_padding_3x3():
|
||||
"""Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions."""
|
||||
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5]
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(54).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER")
|
||||
graph = helper.make_graph([node], "conv_same_3x3", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx")
|
||||
|
||||
|
||||
def conv_explicit_padding():
|
||||
"""Conv 3x3 with explicit asymmetric padding."""
|
||||
# Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4]
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1])
|
||||
graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx")
|
||||
|
||||
|
||||
def conv_with_bias_3x3():
|
||||
"""Conv 3x3 with bias."""
|
||||
# Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3]
|
||||
rng = np.random.default_rng(56)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
|
||||
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
|
||||
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "conv_with_bias_3x3", [X], [Y], initializer=[W, B])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx")
|
||||
|
||||
|
||||
def conv_batch_2():
|
||||
"""Batched conv (batch=2) with SAME_UPPER padding and bias."""
|
||||
# Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3]
|
||||
rng = np.random.default_rng(57)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W")
|
||||
B = numpy_helper.from_array(rng.uniform(-1, 1, (1,)).astype(np.float32), name="B")
|
||||
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
|
||||
kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER")
|
||||
graph = helper.make_graph([node], "conv_batch_2", [X], [Y], initializer=[W, B])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/batch_2", "conv_batch_2.onnx")
|
||||
|
||||
|
||||
def conv_large_spatial():
|
||||
"""Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6])
|
||||
W = numpy_helper.from_array(
|
||||
np.random.default_rng(58).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "conv_large_spatial", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pooling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def maxpool_basic():
|
||||
"""MaxPool 2x2 with stride 1."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "maxpool_basic", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_basic", "maxpool_basic.onnx")
|
||||
|
||||
|
||||
def maxpool_stride2_multichannel():
|
||||
"""MaxPool 2x2 with stride 2 on multiple channels."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3])
|
||||
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx")
|
||||
|
||||
|
||||
def maxpool_same_upper():
|
||||
"""MaxPool 3x3 with SAME_UPPER padding."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER")
|
||||
graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx")
|
||||
|
||||
|
||||
def avgpool_basic():
|
||||
"""AveragePool 2x2 with stride 1."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3])
|
||||
node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([node], "avgpool_basic", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/avg_basic", "avgpool_basic.onnx")
|
||||
|
||||
|
||||
def avgpool_explicit_padding():
|
||||
"""AveragePool 3x3 with explicit padding, excluding pad from the divisor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||
node = helper.make_node("AveragePool", ["X"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0)
|
||||
graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx")
|
||||
|
||||
|
||||
def avgpool_include_pad():
|
||||
"""AveragePool 3x3 with explicit padding, including pad in the divisor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
|
||||
node = helper.make_node("AveragePool", ["X"], ["Y"],
|
||||
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1)
|
||||
graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx")
|
||||
|
||||
|
||||
def maxpool_after_conv():
|
||||
"""Conv followed by MaxPool to validate pooling on lowered conv results."""
|
||||
rng = np.random.default_rng(59)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||
graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating GEMM tests:")
|
||||
gemm_non_square()
|
||||
gemm_with_bias()
|
||||
gemm_transB()
|
||||
gemm_alpha_beta()
|
||||
gemm_small()
|
||||
gemm_large()
|
||||
gemm_transB_with_bias()
|
||||
|
||||
print("\nGenerating Conv tests:")
|
||||
conv_3x3_kernel()
|
||||
conv_stride2()
|
||||
conv_multi_channel()
|
||||
conv_1x1()
|
||||
conv_same_padding_3x3()
|
||||
conv_explicit_padding()
|
||||
conv_with_bias_3x3()
|
||||
conv_batch_2()
|
||||
conv_large_spatial()
|
||||
|
||||
print("\nGenerating Pooling tests:")
|
||||
maxpool_basic()
|
||||
maxpool_stride2_multichannel()
|
||||
maxpool_same_upper()
|
||||
avgpool_basic()
|
||||
avgpool_explicit_padding()
|
||||
avgpool_include_pad()
|
||||
maxpool_after_conv()
|
||||
|
||||
print("\nDone.")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user