Compare commits

...

34 Commits

Author SHA1 Message Date
NiccoloN
0478d979ff minor CI fix
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h8m44s
2026-03-23 20:45:10 +01:00
NiccoloN
da01e6d697 refactor Pim passes layout 2026-03-23 20:14:59 +01:00
NiccoloN
f869925b64 remove old unused stuff 2026-03-23 20:00:09 +01:00
NiccoloN
f2d593f749 minor CI fixes 2026-03-23 19:43:50 +01:00
NiccoloN
a1b29dffe0 fix CI (hopefully)
Some checks failed
Validate Operations / validate-operations (push) Failing after 1h2m33s
2026-03-23 19:27:53 +01:00
NiccoloN
661170a9aa reimplement pool lowering
add pool validation
align PIM ops/codegen/parser with the ISA
move constant materialization to MLIR
rename the PIM verification/materialization passes
better folded-constant handling
2026-03-23 19:14:50 +01:00
NiccoloN
461bdd808d replace helper-op cleanup with canonicalization
clean up PIM pattern naming
remove unused ValueMap.hpp
2026-03-23 17:13:54 +01:00
NiccoloN
50c545539b clean up PIM CMake
update README.md
2026-03-23 16:39:14 +01:00
NiccoloN
11916a2595 refactor Pim constant folding pass
share contiguous address resolution in PimCommon
group patterns in subdir for each pass with pattern files
2026-03-23 15:36:58 +01:00
NiccoloN
670d6ce94f extend operation support for conv and gemm
add more tests in validation
2026-03-23 14:46:08 +01:00
NiccoloN
2676f2c7ef fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 1m7s
Validate Operations / build-mlir-cache (push) Successful in 2h1m3s
Validate Operations / validate (push) Failing after 56m7s
2026-03-23 14:30:02 +01:00
NiccoloN
45342190bb fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 1m13s
Validate Operations / build-mlir-cache (push) Successful in 3m5s
Validate Operations / validate (push) Failing after 3m5s
2026-03-23 13:24:00 +01:00
NiccoloN
f629e0d99f fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 1m12s
Validate Operations / build-mlir-cache (push) Successful in 3m13s
Validate Operations / validate (push) Failing after 12m29s
2026-03-23 11:38:54 +01:00
NiccoloN
568529ea5f fix batched conv
Some checks failed
Validate Operations / config (push) Successful in 1m3s
Validate Operations / build-mlir-cache (push) Successful in 3m40s
Validate Operations / validate (push) Failing after 2m37s
2026-03-20 22:00:46 +01:00
NiccoloN
ca2e1645bb simple convolutions now work :) 2026-03-20 21:17:02 +01:00
NiccoloN
6933804003 constant fold linalg.map (generated from tensor.pad for padding)
refactor pim helpers in PimCommon
2026-03-20 20:51:20 +01:00
NiccoloN
dbe646ac0d fix gemm segfault
Some checks failed
Validate Operations / config (push) Successful in 1m27s
Validate Operations / build-mlir-cache (push) Successful in 2h0m37s
Validate Operations / validate (push) Failing after 3m41s
print exit signals on validation failure
2026-03-20 14:00:16 +01:00
NiccoloN
bb6dcd38a3 replace deprecated "rewriter.create()" pattern
refactor PIM to Pim everywhere except for the accelerator name
2026-03-20 13:30:53 +01:00
NiccoloN
916a09414c add validation artifacts cleanup 2026-03-20 13:15:08 +01:00
NiccoloN
db3f52a647 conv now lowers correctly down to bufferized pim 2026-03-20 12:55:09 +01:00
NiccoloN
6e1de865bb add constant folding and verification pass for pim host operations
better validation scripts output
big refactors
2026-03-20 12:08:12 +01:00
NiccoloN
4e50e056e3 replace old convolution support in spatial (WIP)
Some checks failed
Validate Operations / config (push) Successful in 1m46s
Validate Operations / build-mlir-cache (push) Successful in 3m25s
Validate Operations / validate (push) Failing after 2m42s
2026-03-13 17:46:10 +01:00
NiccoloN
771b44a2ed fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 53s
Validate Operations / validate (push) Has been cancelled
Validate Operations / build-mlir-cache (push) Has been cancelled
2026-03-11 15:41:26 +01:00
NiccoloN
7ce1d2b34d fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 52s
Validate Operations / build-mlir-cache (push) Successful in 2h0m25s
Validate Operations / validate (push) Failing after 2m26s
2026-03-10 16:12:26 +01:00
NiccoloN
584ca0b3c2 fix CI (hopefully)
Some checks failed
Validate Operations / config (push) Successful in 1m7s
Validate Operations / build-mlir-cache (push) Failing after 2m11s
Validate Operations / validate (push) Has been skipped
2026-03-09 14:30:41 +01:00
NiccoloN
1348bb1c97 generic gemm now works :) 2026-03-06 18:23:27 +01:00
NiccoloN
825188cc89 Merge remote-tracking branch 'origin/main' 2026-03-06 15:44:30 +01:00
NiccoloN
7202a4317d add free disk space step to CI 2026-03-06 15:44:23 +01:00
ilgeco
d4efa64b96 pim-simulator auto-format 2026-03-04 20:00:01 +01:00
ilgeco
fef26cee9a pim-simulator unwrap on failed json parsing 2026-03-04 19:59:16 +01:00
ilgeco
29febb2bfd pim-simulator dump instruction 2026-03-04 19:57:24 +01:00
ilgeco
f24a60bfcd pim-simulato trace end of lmv 2026-03-04 19:56:59 +01:00
ilgeco
91ef6d9bc3 pim-simulator dump inst function 2026-03-04 19:56:30 +01:00
NiccoloN
8ee1e5ece8 implement mem copy codgen (lmv)
add more gemv/gemm tests
refactor
2026-03-04 18:04:48 +01:00
140 changed files with 5455 additions and 4526 deletions

View File

@@ -0,0 +1,54 @@
name: Prepare MLIR Cache
description: Restore or build the cached MLIR toolchain.
inputs:
llvm-commit:
description: LLVM commit to build.
required: true
mold-linker-flags:
description: Linker flags used to force mold.
required: true
runs:
using: composite
steps:
- name: Restore MLIR Cache
id: restore-mlir-cache
uses: actions/cache/restore@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
- name: Clone LLVM
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
shell: bash
run: |
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
cd onnx-mlir/llvm-project
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
git checkout FETCH_HEAD
- name: Build MLIR
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
shell: bash
run: |
cmake -S onnx-mlir/llvm-project/llvm -B onnx-mlir/llvm-project/build -G Ninja \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_ENABLE_RUNTIMES="openmp" \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DENABLE_LIBOMPTARGET=OFF \
-DLLVM_ENABLE_LIBEDIT=OFF \
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
cmake --build onnx-mlir/llvm-project/build
- name: Save MLIR Cache
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}

View File

@@ -0,0 +1,45 @@
name: Prepare Protobuf Cache
description: Restore or build the cached Protobuf installation.
inputs:
protobuf-ref:
description: Protobuf tag or commit to build.
required: true
mold-linker-flags:
description: Linker flags used to force mold.
required: true
runs:
using: composite
steps:
- name: Restore Protobuf Cache
id: restore-protobuf-cache
uses: actions/cache/restore@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-ref }}
- name: Install Protobuf
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
shell: bash
run: |
git clone --depth 1 --branch ${{ inputs.protobuf-ref }} https://github.com/protocolbuffers/protobuf protobuf
cmake -S protobuf -B protobuf/build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
cmake --build protobuf/build
sudo cmake --install protobuf/build
rm -rf protobuf
- name: Save Protobuf Cache
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-ref }}

View File

@@ -0,0 +1,26 @@
name: Restore Raptor Build Cache
description: Restore the cached Raptor build directory for incremental builds.
inputs:
key:
description: Exact cache key to restore.
required: true
restore-keys:
description: Prefixes used to restore the most recent compatible cache.
required: false
outputs:
cache-hit:
description: Whether the exact cache key was restored.
value: ${{ steps.restore-raptor-build-cache.outputs.cache-hit }}
runs:
using: composite
steps:
- name: Restore Raptor Build Cache
id: restore-raptor-build-cache
uses: actions/cache/restore@v4
with:
path: build
key: ${{ inputs.key }}
restore-keys: ${{ inputs.restore-keys }}

View File

@@ -0,0 +1,16 @@
name: Save Raptor Build Cache
description: Save the Raptor build directory after a successful incremental build.
inputs:
key:
description: Cache key used to save the build directory.
required: true
runs:
using: composite
steps:
- name: Save Raptor Build Cache
uses: actions/cache/save@v4
with:
path: build
key: ${{ inputs.key }}

View File

@@ -1,50 +0,0 @@
name: Build MLIR Cache
on:
workflow_call:
inputs:
llvm-commit:
required: true
type: string
jobs:
build-mlir:
runs-on: ubuntu-latest
steps:
- name: Cache MLIR build
id: cache-mlir
uses: actions/cache@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
- name: Install build dependencies
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
sudo apt update
sudo apt install -y cmake ninja-build build-essential
- name: Clone LLVM
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
cd onnx-mlir/llvm-project
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
git checkout FETCH_HEAD
- name: Build MLIR
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
mkdir -p onnx-mlir/llvm-project/build
cd onnx-mlir/llvm-project/build
cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_ENABLE_RUNTIMES="openmp" \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DENABLE_LIBOMPTARGET=OFF \
-DLLVM_ENABLE_LIBEDIT=OFF
cmake --build .

View File

@@ -2,64 +2,58 @@ name: Validate Operations
env:
LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286
PROTOBUF_COMMIT: v34.0
PROTOBUF_REF: v34.0
CMAKE_VERSION: 4.3.0
MOLD_LINKER_FLAGS: -fuse-ld=mold
on:
push:
pull_request:
jobs:
# Expose env vars as outputs so they can be passed to reusable workflows
config:
runs-on: ubuntu-latest
outputs:
llvm-commit: ${{ steps.set-vars.outputs.llvm-commit }}
steps:
- id: set-vars
run: echo "llvm-commit=$LLVM_COMMIT" >> "$GITHUB_OUTPUT"
build-mlir-cache:
needs: config
uses: ./.github/workflows/build_mlir_cache.yml
with:
llvm-commit: ${{ needs.config.outputs.llvm-commit }}
validate:
needs: build-mlir-cache
validate-operations:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
run: |
git clone --depth 1 --recurse-submodules --branch ${GITHUB_REF_NAME} \
https://chef.heaplab.deib.polimi.it/git/${GITHUB_REPOSITORY}.git \
${GITHUB_WORKSPACE}
- name: Install system dependencies
run: |
sudo apt update
sudo apt install -y cmake ninja-build build-essential
sudo apt install -y cmake ninja-build build-essential mold curl ca-certificates
- name: Cache protobuf build
id: cache-protobuf
uses: actions/cache@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ env.PROTOBUF_COMMIT }}
- name: Install protobuf
if: steps.cache-protobuf.outputs.cache-hit != 'true'
- name: Install CMake
run: |
git clone --depth 1 --branch ${{ env.PROTOBUF_COMMIT }} https://github.com/protocolbuffers/protobuf
cd protobuf
mkdir build
cd build
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
ninja
sudo ninja install
cd ../..
rm -rf protobuf
ARCH="$(uname -m)"
case "$ARCH" in
x86_64) CMAKE_ARCH="linux-x86_64" ;;
aarch64|arm64) CMAKE_ARCH="linux-aarch64" ;;
*) echo "Unsupported architecture: $ARCH"; exit 1 ;;
esac
curl -fsSL -o /tmp/cmake.sh \
"https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-${CMAKE_ARCH}.sh"
sudo sh /tmp/cmake.sh --skip-license --prefix=/usr/local
rm -f /tmp/cmake.sh
cmake --version
which cmake
- name: Prepare MLIR Cache
uses: ./.github/actions/prepare-mlir-cache
with:
llvm-commit: ${{ env.LLVM_COMMIT }}
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
- name: Prepare Protobuf Cache
uses: ./.github/actions/prepare-protobuf-cache
with:
protobuf-ref: ${{ env.PROTOBUF_REF }}
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
- name: Register installed libraries
run: sudo ldconfig
@@ -75,26 +69,37 @@ jobs:
- name: Install Python dependencies
run: pip install numpy onnx colorama
- name: Restore MLIR cache
uses: actions/cache/restore@v4
- name: Restore Raptor Build Cache
id: restore-raptor-build-cache
uses: ./.github/actions/restore-raptor-build-cache
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ env.LLVM_COMMIT }}
fail-on-cache-miss: true
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
restore-keys: |
raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-
raptor-build-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-
- name: Build Raptor
id: build-raptor
run: |
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
mkdir -p build && cd build
cmake .. -G Ninja \
cmake -S . -B build -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
cmake --build .
-DMLIR_DIR=${MLIR_DIR} \
-DCMAKE_EXE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
-DCMAKE_SHARED_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
-DCMAKE_MODULE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}"
cmake --build build
- name: Save Raptor Build Cache
if: steps.build-raptor.outcome == 'success' && steps.restore-raptor-build-cache.outputs.cache-hit != 'true'
uses: ./.github/actions/save-raptor-build-cache
with:
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_REF }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
- name: Run validation
run: |
python validate.py \
python validation/validate.py \
--raptor-path build/Debug/bin/onnx-mlir \
--onnx-include-dir onnx-mlir/include

2
.gitignore vendored
View File

@@ -1,2 +1,4 @@
.idea
.claude
AGENTS.md
build

View File

@@ -29,7 +29,7 @@ Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
Moreover, if compiling with build type debug, it is also suggested to use
mold as linker (you will need to install it if you don't have it already)
to reduce memory usage during linking. You can use it with:
to reduce memory usage during linking. You can use it by setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
@@ -38,8 +38,16 @@ to reduce memory usage during linking. You can use it with:
### Raptor
Use the following commands to build Raptor.
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
```
```
git submodule update --init --recursive

View File

@@ -530,6 +530,7 @@ where
let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0];
@@ -700,6 +701,7 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
let local_memory = core.load::<u8>(r1_val, imm_len)?;
let tmp = local_memory[0].to_vec();
core.execute_store(rd_val, tmp.as_slice());
TRACER.lock().unwrap().post_lmv(cores, data);
Ok(InstructionStatus::Completed)
}

View File

@@ -46,6 +46,10 @@ impl Instruction {
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
.unwrap()
}
pub(crate) fn dump(&self) {
eprintln!("\t{}", functor_to_name(self.functor as usize));
}
}
pub type Instructions = Vec<Instruction>;

View File

@@ -71,18 +71,27 @@ pub fn json_to_instruction(
inst_builder,
inst_data_builder,
json,
);
)
.unwrap();
}
macro_rules! json_str {
($json:ident , $value:literal) => {
$json.get($value).context(concat![$value, " field not present"])?.as_str().context(concat![$value, " field not str"])?
$json
.get($value)
.context(concat![$value, " field not present"])?
.as_str()
.context(concat![$value, " field not str"])?
};
}
macro_rules! json_i64 {
($json:ident , $value:literal) => {
$json.get($value).context(concat![$value, " field not present"])?.as_i64().context(concat![$value, " field not i64"])?
$json
.get($value)
.context(concat![$value, " field not present"])?
.as_i64()
.context(concat![$value, " field not i64"])?
};
}
@@ -215,7 +224,21 @@ fn json_to_vvsub(
inst_data_builder: &mut InstructionDataBuilder,
json: &Value,
) -> Result<()> {
todo!("Not present in the compiler");
let json = json.as_object().expect("Not an object");
assert_eq!("vvsub", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvsub, inst_data_builder.build());
Ok(())
}
@@ -247,7 +270,21 @@ fn json_to_vvdmul(
inst_data_builder: &mut InstructionDataBuilder,
json: &Value,
) -> Result<()> {
todo!("Not present in the compiler");
let json = json.as_object().expect("Not an object");
assert_eq!("vvdmul", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvdmul, inst_data_builder.build());
Ok(())
}
@@ -297,7 +334,21 @@ fn json_to_vavg(
inst_data_builder: &mut InstructionDataBuilder,
json: &Value,
) -> Result<()> {
todo!("Not present in the compiler");
let json = json.as_object().expect("Not an object");
assert_eq!("vavg", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vavg, inst_data_builder.build());
Ok(())
}
@@ -349,7 +400,7 @@ fn json_to_vsigm(
json: &Value,
) -> Result<()> {
let json = json.as_object().expect("Not an object");
assert_eq!("vsigmoid", json_str!(json, "op"));
assert_eq!("vsigm", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let len = json_i64!(json, "len") as i32;

View File

@@ -1,7 +1,7 @@
#![allow(unused)]
use crate::{
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
};
pub mod cpu;
pub mod instruction_set;
@@ -131,6 +131,16 @@ impl Executable {
pub fn cpu_mut(&mut self) -> &mut CPU {
&mut self.cpu
}
pub fn dump(&self) {
let core_instructions = &self.core_instructions;
for (i, core_instruction) in core_instructions.iter().enumerate() {
eprintln!("INST OF CORE {}:", i);
for inst in &core_instruction.instructions {
inst.dump();
}
}
}
}
fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {

View File

@@ -1123,7 +1123,7 @@ impl Trace {
if prefix == "Pre" {
writeln!(
file,
"Inst: lvm {} {} {} {{ {} {} }}",
"Inst: lmv {} {} {} {{ {} {} }}",
rd, r1, imm_len, offset_select, offset_value
);
} else {
@@ -1141,13 +1141,15 @@ impl Trace {
);
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
let core_memory = core.load::<u8>(r1_val, imm_len).unwrap();
let global_memory = host.load::<u8>(rd_val, imm_len).unwrap();
let core_memory = core
.reserve_load(r1_val, imm_len).unwrap()
.reserve_load(rd_val, imm_len).unwrap()
.execute_load::<u8>().unwrap();
writeln!(file, "{} Memory:", prefix);
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
pretty_print::print_slice::<_,f32>(file, global_memory[0], 30);
pretty_print::print_slice::<_,f32>(file, core_memory[1], 30);
if prefix == "Post" {
writeln!(file, "\n###############################################\n");

View File

@@ -10,38 +10,61 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
add_subdirectory(Dialect)
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_subdirectory(Common)
add_onnx_mlir_library(OMPIMAccel
PimAccelerator.cpp
Transforms/PimBufferizationPass.cpp
Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
)
set(PIM_COMPILER_INCLUDE_DIRS
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
set(PIM_ACCEL_INCLUDE_DIRS
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
set(PIM_GENERATED_INCLUDE_DIRS
${PIM_INCLUDE_PATH}
)
function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN})
endfunction()
add_subdirectory(Dialect)
add_subdirectory(Common)
add_subdirectory(Pass)
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_pim_library(OMPIMAccel
PimAccelerator.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
onnx
OMAccelerator
OMPimCompilerUtils
OMCompilerUtils
OMPimPasses
OMONNXOps
SpatialOps
PimOps
OMONNXToSpatial
OMSpatialToGraphviz
OMSpatialToPIM
OMPIMCommon
OMSpatialToPim
OMPimCommon
OMPimBufferization
MLIRTensorInferTypeOpInterfaceImpl
)

View File

@@ -1,19 +1,13 @@
add_onnx_mlir_library(OMPIMCommon
PIMCommon.cpp
add_pim_library(OMPimCommon
PimCommon.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
onnx
OMPimCompilerUtils
SpatialOps
PimOps
)

View File

@@ -1,93 +0,0 @@
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string dialectsDir = getOutputDir() + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
} // namespace onnx_mlir

View File

@@ -1,24 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/StringRef.h"
#include "src/Compiler/CompilerOptions.hpp"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
namespace onnx_mlir {
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
} // namespace onnx_mlir

View File

@@ -0,0 +1,302 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
int64_t byteOffset = 0;
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress{value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = tiedOperand->get();
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|| llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress{value, byteOffset};
return failure();
}
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,58 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Compiler/CompilerOptions.hpp"
const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
} // namespace onnx_mlir

View File

@@ -1,44 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
template <typename T>
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
map.erase(result);
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
map.erase(arg);
}
};
template <typename T>
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
NotErasableValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
}
};

View File

@@ -1,44 +1,36 @@
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
add_onnx_mlir_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerOptions
PimCompilerOptions.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_COMPILER_INCLUDE_DIRS}
LINK_LIBS PUBLIC
${OMLibs}
OMCompilerOptions
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_ACCEL_INCLUDE_DIRS}
)
add_onnx_mlir_library(OMPimCompilerUtils
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimCodeGen.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_COMPILER_INCLUDE_DIRS}
LINK_LIBS PUBLIC
${OMLibs}
OMCompilerUtils
OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
OMCompilerPasses
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_ACCEL_INCLUDE_DIRS}
)

View File

@@ -13,17 +13,18 @@
#include <cassert>
#include <cmath>
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
MemEntry* PimMemory::gatherMemEntry(Value value) {
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
@@ -31,7 +32,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
return &memEntries.emplace_back(memEntry, value).first;
}
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) {
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size;
// Alignment
@@ -47,8 +48,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!getGlobalOp->hasAttr("weightAlways")) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
auto iter = globalConstants.find(globalMemrefOp);
if (iter == globalConstants.end())
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
@@ -59,7 +60,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
}
});
for (Value arg : funcOp.getArguments())
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
allocateCore(funcOp);
@@ -73,40 +74,44 @@ void PimMemory::allocateCore(Operation* op) {
allocateMemoryForValue(value, memEntry);
}
MemEntry PimMemory::getMemEntry(Value value) const {
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
return iter->second;
}
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second;
}
size_t PimAcceleratorMemory::getValueAddress(Value value) const {
while (true) {
auto definingOp = value.getDefiningOp();
if (!definingOp)
break;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
break;
value = tiedOperand->get();
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
auto resolvedAddress = resolveContiguousAddress(value);
if (failed(resolvedAddress)) {
errs() << "Failed to resolve contiguous address for value: ";
value.print(errs());
errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto source = subviewDefiningOp.getSource();
auto srcShape = source.getType().getShape();
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
value = source;
}
else
break;
llvm_unreachable("Failed to resolve contiguous address");
}
return memEntriesMap.at(value).address;
auto iter = memEntriesMap.find(resolvedAddress->base);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
resolvedAddress->base.print(errs());
errs() << "\n";
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Missing mem entry");
}
return iter->second.address + resolvedAddress->byteOffset;
}
json::Object PimCodeGen::createEmptyOffset() {
@@ -144,15 +149,20 @@ void PimCodeGen::setupRdRs1Rs2(
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
}
void PimCodeGen::emitMemCopyOp(
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const {
void PimCodeGen::emitMemCopyOp(StringRef opName,
size_t rdAddr,
size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
StringRef sizeFieldName) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
json::Object json;
json["op"] = opName;
json["rd"] = 0;
json["rs1"] = 1;
json["size"] = size;
json[sizeFieldName] = size;
json["offset"] = createEmptyOffset();
emitInstruction(std::move(json));
}
@@ -200,6 +210,16 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
storeOp.getSize());
}
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
emitMemCopyOp("lmv",
memory.getValueAddress(lmvOp.getDst()),
lmvOp.getDstOffset(),
memory.getValueAddress(lmvOp.getSrc()),
lmvOp.getSrcOffset(),
lmvOp.getSize(),
"len");
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
emitCommunicationOp(
"recv", memory.getValueAddress(receiveOp.getDst()), receiveOp.getSrcCoreId(), receiveOp.getSize());
@@ -217,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vaddOp.getA());
auto bAddr = memory.getValueAddress(vaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvaddOp.getA());
auto bAddr = memory.getValueAddress(vvaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvadd";
@@ -232,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = totalBytes;
json["len"] = getValueSizeInBytes(vvaddOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vmaxOp.getA());
auto bAddr = memory.getValueAddress(vmaxOp.getB());
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvsubOp.getA());
auto bAddr = memory.getValueAddress(vvsubOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvsub";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvsubOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmulOp.getA());
auto bAddr = memory.getValueAddress(vvmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
@@ -248,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvdmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
auto aAddr = memory.getValueAddress(vavgOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vavg";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vavgOp.getA());
emitInstruction(std::move(json));
}
@@ -261,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vreluOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
auto aAddr = memory.getValueAddress(vtanhOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vtanh";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vtanhOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
auto aAddr = memory.getValueAddress(vsigmOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vsigm";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vsigmOp.getA());
emitInstruction(std::move(json));
}
@@ -318,12 +432,59 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
vaddJson["rs1"] = 1;
vaddJson["rs2"] = 2;
vaddJson["offset"] = createEmptyOffset();
vaddJson["len"] = 32 * outChannels;
emitInstruction(std::move(vaddJson));
}
}
}
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
auto srcAddr = memory.getValueAddress(transposeOp.getData());
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
auto srcShape = srcType.getShape();
size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);
for (size_t i = 0; i < rank; i++)
dstShape[i] = srcShape[perm[i]];
// Row-major strides for source and destination
SmallVector<size_t> srcStrides(rank, 1);
SmallVector<size_t> dstStrides(rank, 1);
for (int64_t i = rank - 2; i >= 0; i--) {
srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
}
// Emit element-by-element copy with transposed addressing
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
// Decompose flat source index into multi-dimensional index
SmallVector<size_t> srcIdx(rank);
size_t remaining = srcFlat;
for (size_t d = 0; d < rank; d++) {
srcIdx[d] = remaining / srcStrides[d];
remaining %= srcStrides[d];
}
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
size_t dstFlat = 0;
for (size_t d = 0; d < rank; d++)
dstFlat += srcIdx[perm[d]] * dstStrides[d];
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
}
}
size_t getMatrixSize(ShapedType matrixShape) {
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
assert(false && "Unsupported matrix shape");
@@ -343,7 +504,6 @@ std::string getMemorySizeAsString(size_t size) {
/// Write global constant data into a binary memory image at their allocated addresses.
static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
@@ -355,9 +515,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (getGlobalOp->hasAttr("weightAlways"))
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
auto initialValue = globalOp.getInitialValue();
@@ -393,30 +553,46 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op))
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
coreCodeGen.codeGenVMaxOp(vmaxOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp);
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (isa<pim::PimSumOp>(op)) {
// TODO: Implement somehow?
op.emitWarning("Operation is not yet supported in code generation");
continue;
@@ -450,7 +626,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
continue;
}
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr());
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
@@ -539,7 +715,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (Value output : returnOp.getOperands())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
@@ -564,9 +740,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
}
}
auto funcOps = moduleOp.getOps<func::FuncOp>();
assert(!funcOps.empty() && "No function found in the module");
auto funcOp = *funcOps.begin();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc))
return CompilerFailure;
auto funcOp = *entryFunc;
PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp);

View File

@@ -3,69 +3,62 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/Support/JSON.h"
#include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
using namespace llvm;
using namespace mlir;
using Value = mlir::Value;
using Type = mlir::Type;
using FunctionType = mlir::FunctionType;
struct MemEntry {
size_t address;
size_t size;
};
class PimMemory {
SmallVector<std::pair<MemEntry, Value>, 32> memEntries;
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(Value value);
void allocateMemoryForValue(Value value, MemEntry& memEntry);
MemEntry* gatherMemEntry(mlir::Value value);
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
public:
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp);
void allocateCore(Operation* op);
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(Value value) const;
MemEntry getMemEntry(mlir::Value value) const;
};
class PimAcceleratorMemory {
public:
SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
SmallDenseMap<size_t, PimMemory> deviceMem;
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {}
PimMemory getOrCreateDeviceMem(size_t id);
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(Value value) const;
size_t getValueAddress(mlir::Value value) const;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
raw_fd_ostream& coreFileStream;
llvm::raw_fd_ostream& coreFileStream;
static json::Object createEmptyOffset();
void emitInstruction(json::Object instruction) const;
static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const;
@@ -73,17 +66,23 @@ class PimCodeGen {
void setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
void
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const;
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
void emitMemCopyOp(mlir::StringRef opName,
size_t rdAddr,
size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
mlir::StringRef sizeFieldName = "size") const;
void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public:
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson)
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
void codeGenSendOp(pim::PimSendOp sendOp) const;
@@ -91,12 +90,19 @@ public:
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
};
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName);
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir

View File

@@ -25,7 +25,6 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> exportCrossbarWeights;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;

View File

@@ -2,10 +2,9 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#define DEBUG_TYPE "PimCompilerUtils"
@@ -34,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPIMPass());
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
}
@@ -46,6 +45,11 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createMaterializeConstantsPass());
pm.addPass(createVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted"));

View File

@@ -1,3 +1,3 @@
add_subdirectory(ONNXToSpatial)
add_subdirectory(SpatialToGraphviz)
add_subdirectory(SpatialToPIM)
add_subdirectory(SpatialToPim)

View File

@@ -2,33 +2,29 @@ set(LLVM_TARGET_DEFINITIONS ONNXToSpatial.td)
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.cpp
Math/Conv.cpp
Math/ExperimentalConv.cpp
Math/ExperimentalGemm.cpp
NN/Pooling.cpp
NN/ExperimentalPooling.cpp
NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp
Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp
ONNXToSpatialPass.hpp
add_pim_library(OMONNXToSpatial
Patterns/Math/Gemm.cpp
Patterns/Math/Conv.cpp
Patterns/Math/MatMul.cpp
Patterns/NN/Pool.cpp
Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp
ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp
Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions
OMPimCompilerOptions
OMONNXOps
SpatialOps
OMPIMCommon
OMPimCommon
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
${PIM_GENERATED_INCLUDE_DIRS}
)

View File

@@ -15,7 +15,7 @@
#include <optional>
#include <utility>
#include "ONNXToSpatialCommon.hpp"
#include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -47,7 +47,7 @@ SmallVector<Value> sliceTensor(
if (i == numSlices - 1 && lastSliceSize != 0)
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides);
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
slices.push_back(slice);
}
@@ -100,11 +100,11 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult();
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
return rewriter.create<tensor::SplatOp>(loc, type, elementValue);
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
}
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
@@ -122,7 +122,7 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
@@ -137,10 +137,10 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
switch (mapOp) {
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input);
case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
}
}
@@ -180,10 +180,10 @@ void tileImageTensorByChannel(Value imageTensor,
ConversionPatternRewriter& rewriter) {
ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
size_t input_h = GET_IMAGE_HEIGHT(imageShape);
size_t input_w = GET_IMAGE_WIDTH(imageShape);
size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize);
size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize;
size_t input_h = getImageHeight(imageShape);
size_t input_w = getImageWidth(imageShape);
size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
size_t tileRest = getImageChannel(imageShape) % tileSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
@@ -201,7 +201,7 @@ void tileImageTensorByChannel(Value imageTensor,
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides);
tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
}
}
}
@@ -225,7 +225,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat);
return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
}
LogicalResult
@@ -271,7 +271,7 @@ Value createExtractSliceImg(Value valToSlice,
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
}
Value indexImgValue(Value v,
@@ -384,7 +384,7 @@ void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides);
inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
}
}
}
@@ -452,7 +452,7 @@ LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
Value shapeTensor =
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);

View File

@@ -9,24 +9,53 @@
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1)
#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0)
#define GET_KERNEL_WIDTH(shapedType) shapedType.getDimSize(2)
#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3)
#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0)
using namespace mlir;
namespace onnx_mlir {
const StringRef REPLICATION_ATTR_NAME = "replication_factor";
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
@@ -58,51 +87,64 @@ constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
}
template <class T>
bool isVectorShape(const ArrayRef<T> shape) {
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(const ArrayRef<T> shape) {
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(const ArrayRef<T> shape) {
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(const ArrayRef<T> shape) {
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(const ArrayRef<T> shape) {
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(const Value tensor) { return cast<RankedTensorType>(tensor.getType()).getShape(); }
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input);
mlir::Value createMapOperation(mlir::PatternRewriter& rewriter, MapOperations mapOp, const mlir::Value& input);
/**
* Unpacks an optional pair vector into two size_t values.
@@ -126,7 +168,8 @@ void unpackOptionalPairVector(std::optional<mlir::ArrayAttr> valuesArray, size_t
*
* @return llvm::Optional<llvm::Twine> The error message if the pads are invalid
*/
std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
std::optional<llvm::Twine>
unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y);
/**
* Tiles the image tensor by channel.
@@ -140,10 +183,10 @@ std::optional<Twine> unpackOptionalPadsVector(std::optional<mlir::ArrayAttr> val
* @param tileSize The size of each tile.
* @param rewriter The ConversionPatternRewriter used for creating operations.
*/
void tileImageTensorByChannel(Value imageTensor,
SmallVector<SmallVector<SmallVector<Value>>>& tiles,
void tileImageTensorByChannel(mlir::Value imageTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& tiles,
size_t tileSize,
ConversionPatternRewriter& rewriter);
mlir::ConversionPatternRewriter& rewriter);
/**
* Creates an ImgConcatOp based on the given tiles.
@@ -159,10 +202,10 @@ void tileImageTensorByChannel(Value imageTensor,
*
* @return The created ImgConcatOp.
*/
Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
ConversionPatternRewriter& rewriter,
Location& loc,
Type outputType);
mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& outputTiles,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc,
mlir::Type outputType);
/**
* @brief Verifies if the given input coordinates and padding values are within
@@ -177,7 +220,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
* @return LogicalResult Returns success if the coordinates and padding are
* within bounds, failure otherwise.
*/
LogicalResult
mlir::LogicalResult
verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y);
/**
@@ -207,13 +250,14 @@ verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY,
* @return std::optional<llvm::Twine> An error message if the input tensor could
* not be resolved into tiles.
*/
std::optional<Twine> resolveImgInputTiles(Value wholeInputTensor,
SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
mlir::ConversionPatternRewriter& rewriter);
std::optional<llvm::Twine>
resolveImgInputTiles(mlir::Value wholeInputTensor,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<mlir::Value>>>& inputTiles,
size_t channelTileCount,
size_t channelTileRest,
size_t input_w,
size_t input_h,
mlir::ConversionPatternRewriter& rewriter);
/**
* Computes the boundaries of an image kernel application.
@@ -258,6 +302,6 @@ void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcom
* @return The index of the result of the operation that produces the specified
* value.
*/
int getResultIndex(Operation* op, Value v);
int getResultIndex(mlir::Operation* op, mlir::Value v);
}; // namespace onnx_mlir

View File

@@ -1,583 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
// NOTE:
// This might be useful to re-implement this considering for loops.
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
/**
* @brief A momentary representation of a core, to be used within the tiling of
* a convolution operation.
*/
class Core {
public:
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
: coreId(coreId), rewriter(rewriter) {}
/**
* @brief Add a MVM operation to the core.
*
* @param inputTile The input tile to the MVM operation.
* @param xbarIndex The index of the crossbar weight to use.
* @param outputTileId The id of the output tile.
* @param mvmOutType The result's shape.
* @return Value The result of the MVM operation.
*/
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
// Use the inputTile as the reference location for the MVM operation.
Location loc = inputTile.getLoc();
// Move the insertion point to the end of the block.
rewriter.setInsertionPointToEnd(block.get());
// Add the inputTile to the block arguments, and to the operands.
Value operand = operandMap.lookupOrNull(inputTile);
if (not operand) {
operand = block->addArgument(inputTile.getType(), loc);
operands.push_back(inputTile);
operandMap.map(inputTile, operand);
}
// TODO: Compute the output type using the matrix, and check if `mvmOutType`
// is correct.
// Construct the MVM operation
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
// Since we are within the same core and no computation can happen in
// paralllel, we can just apply a linear reduction in case we have multiple
// MVM operations for the same outputTile.
auto lastMVM = outputTileToMVM.find(outputTileId);
// If an entry for this outputTile already exists, apply reduction.
if (lastMVM != outputTileToMVM.end()) {
// MVM results should have the same type for reduction.
assert(lastMVM->second.getType() == result.getType());
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
}
outputTileToMVM[outputTileId] = result;
return result;
}
/**
* @brief Mark a result as remappable, and return a shared pointer to it.
*
* This function marks a result as remappable, and returns a shared pointer to
* it. We need to keep track of these values to generate the YieldOp at a
* later stage.
*
* @param result A result to track, for later remapping.
* @return shared_ptr<Value> A shared pointer to the result.
*/
shared_ptr<Value> makeResultRemappable(Value result) {
// Verify that the result is present in the block.
assert(result.getDefiningOp()->getBlock() == block.get());
shared_ptr<mlir::Value> remappableResult = make_shared<Value>(result);
resultsToRemap.push_back(remappableResult);
results.push_back(result);
return remappableResult;
}
/**
* @brief Add a remappable operand to the core, to merge partial results
* inter-core.
*
* @param remappableOperand The operand to add.
* @return Value The block argument representing the operand.
*/
Value addRemappableOperand(std::shared_ptr<Value> operand) {
// Check that the operand is not already there.
assert(not operandMap.contains(*operand));
Value argument = block->addArgument(operand->getType(), operand->getLoc());
remappableOperands.push_back(operand);
return argument;
}
/**
* @brief Generate a spatial::SpatWeightedCompute operation from the core.
*
* @param loc The location of the operation.
* @return spatial::SpatWeightedCompute
*/
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
// Get the shape of the results.
SmallVector<Type> resultTypes;
for (const auto& value : results)
resultTypes.push_back(value.getType());
// Create the WComputeOp, with non-remappable operands only.
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
// Add the body to the WComputeOp.
Block* releasedBlock = block.release();
wcomputeOp.getBody().push_back(releasedBlock);
// Add the `yieldOp` at the end, with the results.
rewriter.setInsertionPointToEnd(releasedBlock);
rewriter.create<spatial::SpatYieldOp>(loc, results);
return wcomputeOp;
}
/**
* @brief Remap the results to the WComputeOp results.
*/
void remapResults() {
// Remap all the results to the WComputeOp results.
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
for (size_t i = 0; i < resultsToRemap.size(); i++)
*resultsToRemap[i] = wcomputeOp.getResult(i);
}
void addRemappedOperands() {
// Insert the remappableOperands (which were remapped in
// `addRemappableOperand` of another Core)
for (auto remappedValue : remappableOperands)
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
// Update the wcomputeOp operandSegmentSize
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
}
size_t addXbarWeight(Value weight) {
assert(!isXbarsFull());
xbarWeights.push_back(weight);
return xbarWeights.size() - 1;
}
bool isXbarsFull() {
assert(xbarWeights.size() <= crossbarCountInCore);
return xbarWeights.size() == crossbarCountInCore;
}
bool isCoreEmpty() { return block->empty(); }
void dump() {
// Print the coreId
llvm::outs() << "Core " << coreId << ":\n";
// Print the weights
llvm::outs() << "Xbar Weights:\n";
for (auto weight : xbarWeights)
weight.dump();
// Print the operands
llvm::outs() << "Operands:\n";
for (auto operand : operands)
llvm::outs() << operand << "\n";
// Dump the body block
for (auto& op : block->getOperations())
op.dump();
// Print the results
llvm::outs() << "Results:\n";
for (auto result : results)
llvm::outs() << result << "\n";
}
const size_t coreId;
private:
ConversionPatternRewriter& rewriter;
// Should these be set<Value> instead? But I need to keep the order
vector<Value> operands;
vector<std::shared_ptr<Value>> remappableOperands;
vector<Value> results;
vector<std::shared_ptr<Value>> resultsToRemap;
// Maps from input tiles to the block operand
IRMapping operandMap;
// Map from outputTileId to MVM operation producing it
unordered_map<size_t, Value> outputTileToMVM;
vector<Value> xbarWeights;
unique_ptr<mlir::Block> block = make_unique<Block>();
spatial::SpatWeightedCompute wcomputeOp;
};
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
struct Producer_t {
Value value;
shared_ptr<Core> core;
};
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
ShapedType yShape = mlir::cast<ShapedType>(conv.getY().getType());
size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
// TODO: Pad value at beginning and end of each dimension could be
// different. We should handle this case.
// MapOperations mapOperation = MapOperations::None;
//
// // If we have just one user, and it is an activation funcion (or more in
// // general a mapping operation) just inline it in the computeOps
// auto firstUserOp = *conv->getUsers().begin();
// if (conv->hasOneUse()) {
// mapOperation = mlirOpToMapOperationEnum(firstUserOp);
//
// if (mapOperation == MapOperations::ONNXSoftmaxOp) {
// return rewriter.notifyMatchFailure(
// conv, "Softmax not supported as activation for convolutions.");
// }
// }
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
Location loc = conv.getLoc();
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
// Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
// Tile the weight tensor
// Weight tiles need to be indexed by:
// a. Filter Tile
// b. Channel Tile
// c. Kernel `x` position
// d. Kernel `y` position
// For example: weightTiles[filterTile][channelTile][x][y]
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
sizes = {rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
for (size_t i = 0; i < outputTileCount; i++) {
if (i == outputTileCount - 1 && outputTileRemainder != 0)
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
sizes[1] = rewriter.getIndexAttr(crossbarSize);
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
for (size_t j = 0; j < inputTileCount; j++) {
if (j == inputTileCount - 1 && inputTileRemainder != 0)
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
for (size_t x = 0; x < krn_w; x++) {
for (size_t y = 0; y < krn_h; y++) {
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
weightTiles[i][j][x][y] =
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
}
}
}
}
/* Distribute the computation among many compute cores
* Try to compute in-core the computation for each output tile, and reduce
* over as few cores as possible
*/
// Tile the output tensor
// Output tiles need to be indexed by:
// a. Filter Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[filterTile][x][y]
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
size_t replicationFactor;
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
replicationFactor = 1;
else
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
// producers[outTile][out_x][out_y][producerIndex]
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
outputTileCount,
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
// Schedule in cores
size_t coreId = 0;
vector<shared_ptr<Core>> curCores(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++)
curCores[i] = make_shared<Core>(coreId++, rewriter);
vector<shared_ptr<Core>> cores;
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
RankedTensorType mvmOutType =
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
vector<size_t> xbarIndexes(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++)
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
size_t out_x = 0;
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
size_t out_y = 0;
// I use `replicationFactor` cores. I divide the input_w into
// `replicationFactor` slices, and each slice is distributed to a
// core. `coreIndex` is the index of the core that will be used
// for this slice
size_t coreIndex = in_x / replicationSliceSize;
assert(coreIndex < replicationFactor);
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
// Adjust the input based on the kernel
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
// Check if we are within the input image
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
out_y++;
continue;
}
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
auto mvm = curCores[coreIndex]->addMVM(
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
out_y++;
}
out_x++;
}
// Computations for these crossbars are done, check if the cores
// crossbars are fully used. If full, swap with new core
for (size_t i = 0; i < replicationFactor; i++) {
if (curCores[i]->isXbarsFull()) {
cores.emplace_back(std::move(curCores[i]));
curCores[i] = make_shared<Core>(coreId++, rewriter);
}
}
}
}
}
}
for (auto& curCore : curCores)
if (curCore->isCoreEmpty() == false)
cores.emplace_back(std::move(curCore));
curCores.clear();
// Now, do the reduction of each output pixel tile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
for (size_t out_x = 0; out_x < output_w; out_x++) {
for (size_t out_y = 0; out_y < output_h; out_y++) {
// First, check if some producers are within the same core. If this is
// true, `Core::addMVM` have already done the reduction within-core.
// This means that we only need to consider the last producer for that
// core.
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
for (auto producer : producers[outTile][out_x][out_y])
withinCoreReducedProducers[producer.core->coreId] = producer;
// Now, we need to apply inter-core reduction
// Base case with one producer
if (withinCoreReducedProducers.size() == 1) {
// TODO: Add the bias and apply mapping (if present)
auto singleProducer = withinCoreReducedProducers.begin()->second;
// Use last producer as the final result
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
continue;
}
// TODO: This is a linear reduction, not a tree reduction. We can do
// better: a tree reduction would make more computations happen in
// parallel.
Producer_t lastProducer = withinCoreReducedProducers.begin()->second;
auto it = withinCoreReducedProducers.begin();
it++;
while (it != withinCoreReducedProducers.end()) {
Producer_t curProducer = it->second;
shared_ptr<Core> core1;
shared_ptr<Core> core2;
Value core1Value;
Value core2Value;
auto lastProducerCoreId = lastProducer.core->coreId;
auto curProducerCoreId = curProducer.core->coreId;
assert(lastProducerCoreId != curProducerCoreId
&& "We should have already applied within-core reduction, how "
"could we have same cores here?");
// Sort the cores by coreId
if (curProducerCoreId < lastProducerCoreId) {
core1 = curProducer.core;
core1Value = curProducer.value;
core2 = lastProducer.core;
core2Value = lastProducer.value;
}
else {
core1 = lastProducer.core;
core1Value = lastProducer.value;
core2 = curProducer.core;
core2Value = curProducer.value;
}
auto newCoreRes = core1->makeResultRemappable(core1Value);
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
rewriter.setInsertionPointAfterValue(core2Value);
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
lastProducer = {vaddRes, core2};
it++;
}
// TODO: Add the bias and apply mapping (if present)
// Use last producer as the final result
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
}
}
}
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
rewriter.setInsertionPointAfter(conv);
spatial::SpatWeightedCompute lastWComputeOp;
for (auto& core : cores) {
lastWComputeOp = core->createWComputeOp(loc);
core->remapResults();
rewriter.setInsertionPointAfter(lastWComputeOp);
}
for (auto& core : cores)
core->addRemappedOperands();
// Set the insertion point after the last WComputeOp.
rewriter.setInsertionPointAfter(lastWComputeOp);
SmallVector<Value> tilesToConcat;
tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
for (size_t outX = 0; outX < output_h; outX++)
for (size_t outY = 0; outY < output_w; outY++)
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
// Value outputImage =
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
// If no mapping (activation) was applied, just replace ConvOp
// if (mapOperation == MapOperations::None) {
// rewriter.replaceOp(conv, outputImage);
// } else {
// // If mapping was applied, erase ConvOp and replace the mapping op
// rewriter.eraseOp(conv);
// rewriter.replaceOp(firstUserOp, outputImage);
// }
return success();
}
};
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,400 +0,0 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstddef>
#include <unistd.h>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A pattern to tile the convolution operation into a series of compute
* units, each one of which applies filters to a subset of the input
* tensor. Results are also reduced and concatenated to form the final
* output tensor.
*/
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ExperimentalONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(convAdaptor.getX().getType());
ShapedType outputType = cast<ShapedType>(conv.getY().getType());
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
// TODO: Address bigger batches.
assert(GET_IMAGE_N(inputType) == 1
&& "Batch size must be 1"
"for convolution.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
// TODO: Address bias addition.
ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize);
ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize);
size_t kernelWidth = GET_KERNEL_WIDTH(weightsType);
size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType);
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(4, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(filterX),
/* 3 */ rewriter.getIndexAttr(filterY)};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(0),
/* 3 */ rewriter.getIndexAttr(0)};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
conv.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands)
block->addArgument(operand.getType(), conv->getLoc());
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
auto offsets = extractSliceOp.getStaticOffsets();
xKerPos.push_back(offsets[2]);
yKerPos.push_back(offsets[3]);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(
conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(conv.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
// If the conv's user is a ReLU...
if (conv->hasOneUse()) {
Operation* user = *conv->getUsers().begin();
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
// ...then we can just replace the ReLU with the concatenation.
rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
// And erase the convolution.
rewriter.eraseOp(conv);
return success();
}
}
// Return the final output.
rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
return success();
}
};
/**
* @brief Populate the tiling pattern for a convolution operation.
*
* @param patterns The pattern set to populate.
* @param ctx The MLIR context.
*/
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ExperimentalONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,365 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdlib>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
namespace onnx_mlir {
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
ExperimentalGemmConversionPattern(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
// --------------------------------- //
// To get each crossbar's weights, we need to slice the weights tensor.
// - Along the input tiles.
// - Along the output tiles.
// - Along the filter x position.
// - Along the filter y position.
ShapedType inputType = cast<ShapedType>(adaptor.getA().getType());
ShapedType outputType = cast<ShapedType>(gemmOp.getY().getType());
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
// TODO: Address bigger batches.
assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
// TODO: Address bias addition.
assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
size_t kernelWidth = 1;
size_t kernelHeight = 1;
// Assert that the kernel is square.
assert(kernelWidth == kernelHeight && "Only square kernels are supported.");
// -------------------------------- //
// --- SLICE THE WEIGHTS TENSOR --- //
// -------------------------------- //
// The core idea of this stage is classifying the weights by input and
// output tile. This is because we want the applyFilters operations to be
// tile agnostic, to keep the subsequent lowering stages as simple as
// possible. This data structure does this weight classification:
// - The outer map is indexed by input tile.
// - The inner map is indexed by output tile.
// - The SmallVector contains the weights for the filter.
map<long, map<long, SmallVector<Value>>> weightsGroups;
// During all slicing operations within this stage, we'll use the same
// strides for all dimensions.
SmallVector<OpFoldResult> slicingStrides(2, rewriter.getIndexAttr(1));
ldiv_t itc = inputTileCount;
ldiv_t otc = outputTileCount;
// - Slicing along the input tiles.
// - Slicing along the output tiles.
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize;
for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) {
long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize;
// The loop above also sets the crossbar's used width and height,
// checking if we're at the last crossbar and if it's incomplete.
long outputTile = ot;
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ /* rewriter.getIndexAttr(1), */
/* 3 */ /* rewriter.getIndexAttr(1) */};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
for (size_t filterX = 0; filterX < kernelWidth; ++filterX) {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(filterX), */
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
}
}
}
}
// TODO: Tree reduction for compute reduction should be implemented.
// -------------------------------- //
// --- CREATE ALL COMPUTE UNITS --- //
// -------------------------------- //
// Keep track of input slicing operations to avoid duplication across
// all compute units (global slices).
map<long, Value> globalSlices;
// Keep track of all partial compute results.
map<long, Value> globalPartialResults;
// Use a weight subdivider to extract groups of weights for each compute
// unit. We'll keep extracting groups until no more weights are left.
WeightSubdivider weightSubdivider(weightsGroups);
while (!weightSubdivider.isEmpty()) {
// -------------------------------- //
// --- BEGIN A NEW COMPUTE UNIT --- //
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
// ------------------------------ //
// --- SLICE THE INPUT TENSOR --- //
// ------------------------------ //
// Note each tile's index in the compute unit arguments.
map<long, size_t> inputTileIndices;
map<long, size_t> outputTileIndices;
map<long, size_t> reductionTileIndices; // Incoming partial results.
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
// new one.
if (globalSlices.find(group.inputTile) != globalSlices.end()) {
computeOperands.push_back(globalSlices[group.inputTile]);
localSlices[group.inputTile] = globalSlices[group.inputTile];
continue;
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(0), */
/* 3 */ /* rewriter.getIndexAttr(0) */};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
// Update slicing maps.
globalSlices[group.inputTile] = extractSliceOp;
localSlices[group.inputTile] = extractSliceOp;
// Update the input tile index.
inputTileIndices[group.inputTile] = computeOperands.size() - 1;
}
// ------------------------------- //
// --- PREPARE THE OUTPUT TYPE --- //
// ------------------------------- //
// Fill the compute output's type by looking at the output tiles.
SmallVector<Type> computeOutputType;
for (TaggedWeights group : weightsGroups) {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
// ----------------------------- //
// --- FILL THE COMPUTE UNIT --- //
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands)
block->addArgument(operand.getType(), gemmOp->getLoc());
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
for (auto weight : group.weights) {
// Assert that the weight is an extract_slice operation.
auto extractSliceOp = weight.getDefiningOp<tensor::ExtractSliceOp>();
assert(extractSliceOp && "Weight is not an extract_slice operation.");
// Get the filter x and y positions from the extract_slice operation.
xKerPos.push_back(0);
yKerPos.push_back(0);
}
ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices);
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(
gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
localPartialResults[group.outputTile] = result;
}
// Add a yield operation to the block by concatenating the partial
// results.
SmallVector<Value> applyFiltersResults;
for (size_t i = 0; i < computeOutputType.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
// Get that tile's partial result and add it to the list.
applyFiltersResults.push_back(localPartialResults[outputTile]);
}
// Create the yield operation with the given results.
rewriter.create<spatial::SpatYieldOp>(gemmOp.getLoc(), applyFiltersResults);
// Update the global partial results map.
for (size_t i = 0; i < applyFiltersResults.size(); ++i) {
long outputTile;
// Given an output tile index, find the corresponding output tile.
for (auto outputTileIndex : outputTileIndices) {
if (outputTileIndex.second == i) {
outputTile = outputTileIndex.first;
break;
}
}
globalPartialResults[outputTile] = currentCompute.getResult(i);
}
// Move the rewrite cursor out of the block.
rewriter.setInsertionPointAfter(currentCompute);
}
// ------------------------------ //
// --- CONCATENATE THE OUTPUT --- //
// ------------------------------ //
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
// Return the final output.
rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
return success();
}
};
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ExperimentalGemmConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,317 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
ONNXGemmOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location gemmLoc = gemmOp.getLoc();
Value a = adaptor.getA();
Value b = adaptor.getB();
Value c = adaptor.getC();
Value out = gemmOp.getY();
float alpha = adaptor.getAlpha().convertToFloat();
float beta = adaptor.getBeta().convertToFloat();
bool transA = adaptor.getTransA();
bool transB = adaptor.getTransB();
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(out.getType());
RankedTensorType cType = nullptr;
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
assert("Only support 2 tensor for C" && cType.getRank() == 2);
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (transA) {
auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
}
if (transB) {
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
}
if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
}
if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
}
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices;
auto outLastHSliceSize = cLastHSliceSize;
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
SmallVector<Value> cHSlices;
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
if (hasC)
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
RankedTensorType outHSliceType =
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
RankedTensorType outLastHSliceType =
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
SmallVector<Value> outHSlices;
outHSlices.reserve(outNumHSlices);
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
RankedTensorType currOutHSliceType = outHSliceType;
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
currOutHSliceType = outLastHSliceType;
SmallVector<Value> partialResults;
partialResults.reserve(coresPerVSlice);
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
SmallVector<Value> weights;
weights.reserve(aHSlices[coreId].size());
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId])
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
auto computeArgs = computeBlock->getArguments();
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(computeArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
vmmOutputs.push_back(
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp);
partialResults.push_back(computeOp.getResult(0));
}
if (hasC) {
Value cHSlice = cHSlices[outSliceId];
partialResults.push_back(cHSlice);
}
auto reduceComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
auto* reduceBlock = new Block();
for (auto partialResult : partialResults)
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
reduceComputeOp.getBody().push_back(reduceBlock);
rewriter.setInsertionPointToStart(reduceBlock);
auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp);
outHSlices.push_back(reduceComputeOp.getResult(0));
}
rewriter.setInsertionPoint(gemmOp);
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
rewriter.replaceOp(gemmOp, concatOp);
return success();
}
private:
/**
* Resolves the ONNXExpOp from the use chain of the given start value.
*
* This function traverses the use chain of the start value until it finds an
* ONNXExpOp. It returns the value of the ONNXExpOp.
*
* @param startValue The starting value of the use chain.
* @return The value of the ONNXExpOp found in the use chain.
*/
static Value resolveONNXExpOpFromUseChain(Value startValue) {
Value walker = startValue;
while (!llvm::isa<ONNXExpOp>(walker.getDefiningOp())) {
walker = walker.getDefiningOp()->getOperand(0);
assert(walker && walker.getDefiningOp()
&& "Unwinded the whole chain of operations while trying to "
"find ONNXExpOp, but did not find it");
}
// Make sure the dividend is actually produced by an ONNXExpOp
assert(llvm::isa<ONNXExpOp>(walker.getDefiningOp())
&& "Old output tile (softmax reducer) is not produced by an "
"ONNXExpOp");
return walker;
}
// Softmax is a special case, as it requires another reduction after the
// first one. In the cores, `applyReducePattern` already applied
// f(x) = exp(x) to each tile. This mean that now we just need to
// reduce-sum these tiles, and then divide each tile by the reduced sum,
// which is propagated back to the cores via a broadcast channel.
LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
Value& softmaxChannel,
ConversionPatternRewriter& rewriter,
SpatialReducer& reducer,
ONNXGemmOp& gemmOp,
Location& loc) const {
// TODO: Check case with one compute op
// Cast vector of Value into vector of ComputeOp
SmallVector<ComputeAndResNum> softmaxOpsToReduce =
llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) {
return std::make_pair(cast<spatial::SpatWeightedCompute>(computeAndResNum.first), computeAndResNum.second);
}));
RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr);
const TensorType scalarTensorType = tensorTypeBuilder;
reducer.applyReducePattern(
softmaxOpsToReduce,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
/* preprocess = */
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
[&](Value softmaxDivisor) {
// Signal that this is the compute with the softmax divisor
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());
// Broadcast the divisor to all the cores
rewriter.setInsertionPointAfterValue(softmaxDivisor);
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);
/*
* softmaxDividend = onnx.exp (...)
* sum = spat.SumOp(softmaxDividend)
* [following can be repeated N times, thus walk the use chain]
* softmaxDivisor = spat.sadd(sum, ...)
*/
Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));
// Make sure the dividend is actually produced by an ONNXExpOp
assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
&& "Dividend of softmax reduction is not an ONNXExpOp");
// Do not divide here, divide after this
return softmaxDivisor;
});
// In all the cores, insert a ChannelRecvOp and divide the output tile by
// the reduced denominator.
outputOpsAndResNums.clear();
outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
auto yieldOp = cast<spatial::SpatYieldOp>(computeToDivideOpAndResNum.first.getBody().front().getTerminator());
Value divisor;
// Check if this compute contains the softmax divisor: if so, find the
// ChannelBroadcastSendOp, otherwise receive the value from the channel
// using ChannelBroadcastReceiveOp
if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
bool found = false;
for (auto broadcastOp :
computeToDivideOpAndResNum.first.getBody().front().getOps<spatial::SpatChannelBroadcastSendOp>()) {
assert(found == false
&& "More than one ChannelBroadcastSendOp in "
"compute? How is this possible?");
found = true;
divisor = broadcastOp.getData();
}
assert(found
&& "No ChannelBroadcastSendOp in compute where softmax "
"divisor was specified to be?");
}
else {
rewriter.setInsertionPoint(yieldOp);
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
}
// Walk the chain of operations until we find the ONNXExpOp: this is
// needed because some some may have a different amount of `VAddOp`s due
// to the tree reduction (e.g. some may have no VAddOp, some may have
// multiples)
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
rewriter.setInsertionPoint(yieldOp);
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
auto yieldOperandNum = yieldOp->getNumOperands();
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
}
return success();
}
};
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXGemmOpTile>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,300 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename PoolOp>
bool hasPostProcessExperimentalPoolingWindow() {
return false;
}
template <>
bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename ReductionOp>
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
if (inputTiles.size() == 1)
return inputTiles[0];
if (inputTiles.size() == 2) {
return rewriter.create<spatial::SpatVMaxOp>(
inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
}
SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize);
// Assert that the input is a tensor.ConcatOp.
auto concat = X.getDefiningOp<tensor::ConcatOp>();
if (!concat)
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
// Create a [channel_tile][x][y] array to store the input tiles.
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
// For each argument of the tensor.ConcatOp, resolve the input tiles.
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
/* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// Get the concat's operand that we want to slice.
Value concatInput = concat.getOperand(it);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
inputTiles[it][x][y] = slicedTile;
}
}
}
// Prepare the shape of the compute's output.
ldiv_t itc = tileCount;
SmallVector<Type> outputTileTypes;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
/* 2 */ 1,
/* 3 */ 1};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
}
}
}
// Create a plain value list of the input tiles.
SmallVector<Value> inputTilesList;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x)
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
inputTilesList.push_back(inputTiles[it][y][x]);
}
// Create a single compute to calculate the output.
auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
// Create a new block for the compute unit and add the operands.
Block* block = rewriter.createBlock(&computeOp.getRegion());
// Fill the block arguments and keep a reference to them.
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
}
}
}
// Begin writing in the block.
rewriter.setInsertionPointToStart(block);
// Go through all pooling blocks.
SmallVector<Value> outputTiles;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
size_t start_x = x * stride_x;
size_t start_y = y * stride_y;
size_t end_x = std::min(start_x + krn_w, input_w);
size_t end_y = std::min(start_y + krn_h, input_h);
SmallVector<Value> inputTilesToReduce;
for (size_t ky = start_y; ky < end_y; ++ky)
for (size_t kx = start_x; kx < end_x; ++kx)
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
// If the reduce op is add, we need to divide the result by the
// number of elements in the pooling window.
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
// Add a spat.const before the computeOp.
rewriter.setInsertionPoint(computeOp);
auto divisorValue =
rewriter.create<spatial::SpatConstantOp>(loc,
RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
reduceResult =
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
}
outputTiles.push_back(reduceResult);
}
}
}
// Create a YieldOp to return the output tiles.
rewriter.create<spatial::SpatYieldOp>(loc, outputTiles);
// Set the rewrite cursor right after the computeOp.
rewriter.setInsertionPointAfter(computeOp);
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> computeOutput;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
}
}
}
// We'll now create spat.img.concat ops to concatenate the output tiles.
SmallVector<Value> outputTilesList;
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<Value> imgConcatTiles;
for (size_t y = 0; y < output_h; ++y)
for (size_t x = 0; x < output_w; ++x)
imgConcatTiles.push_back(computeOutput[it][y][x]);
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ (long) tilingSize,
/* 2 */ (long) output_w,
/* 3 */ (long) output_h};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
}
// Create a new tensor.ConcatOp to concatenate the output tiles.
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.replaceOp(poolOp, outputTensor);
return success();
}
};
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
ctx);
}
} // namespace onnx_mlir

View File

@@ -1,427 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
// Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1)
return valuesToReduce[0];
if (preprocess) {
for (auto& valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce);
}
}
// It is possible that `valuesToReduce` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& valToReduce : valuesToReduce) {
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
// } else {
// // Otherwise it is a block argument,
// computeOp->getBlock()->getParentOp();
// }
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
auto it = lastValueForCompute.find(computeOp);
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
else
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce;
}
lastValueForCompute[computeOp] = valToReduce;
}
// Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute)
valuesToReduce.push_back(entry.second);
Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the valuesToReduce list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
while (valuesToReduceRef.size() > 1) {
SmallVector<Value> nextValuesToReduce;
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
auto firstValue = valuesToReduceRef[i];
auto secondValue = valuesToReduceRef[i + 1];
auto firstCompute = firstValue.getParentBlock()->getParentOp();
auto secondCompute = secondValue.getParentBlock()->getParentOp();
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstValue, secondValue);
std::swap(firstCompute, secondCompute);
}
// 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute);
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
// 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue);
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue);
// 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue);
Value reduced = reduce(receivedValue, secondValue);
nextValuesToReduce.push_back(reduced);
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (valuesToReduceRef.size() % 2 == 1)
nextValuesToReduce.push_back(valuesToReduceRef.back());
// Replace the inputOps list with the new one.
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
}
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0];
if (postprocess) {
rewriter.setInsertionPointAfterValue(finalValue);
finalValue = postprocess(finalValue);
}
return finalValue;
}
template <typename PoolOp>
bool hasPostProcessPoolingWindow() {
return false;
}
template <>
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = GET_IMAGE_HEIGHT(xShape);
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
// 1: Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt =
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
// TODO: This requires a core for each input tile, which is not ideal. We
// can do better.
// If some input tiles come from the func.func operands, load
// them into a computeOp and yield them
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(),
extractSliceOp.getResult());
Block* tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock);
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0);
}
}
}
}
// 2: Tile the output tensor
// Output tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel
SmallVector<Value> valuesToPool;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
// Each output pixel tile is computed by pooling a window of input
// pixel tiles
valuesToPool.clear();
size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++;
continue;
}
Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool;
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber);
}
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
}
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData();
}
else {
return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
}
valuesToPool.push_back(valueToPool);
}
}
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???");
std::function<Value(const Value&)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
};
}
Value reducedWithinCompute = applyReducePatternNew(
valuesToPool,
rewriter,
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
nullptr,
postProcessFn);
// Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the
// users computeOps
auto computeOpOfReduced =
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel =
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue =
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue;
}
}
}
// TODO: outputTiles are not the results of the computeOps! We need to add
// them!
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
auto outputTile = outputTiles[outTile][outX][outY];
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
}
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
}
}
}
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage);
return success();
}
};
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,88 +0,0 @@
#include "mlir/Transforms/DialectConversion.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV13Op> {
ReduceMeanConversionPattern(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
ONNXReduceMeanV13OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
// Get the input tensor.
Value inputTensor = adaptor.getData();
auto inputTensorType = cast<RankedTensorType>(inputTensor.getType());
// This pattern will substitute the ONNXReduceMeanV13Op with a
// ONNXAveragePoolOp with the same input tensor and an appropriate kernel
// shape and strides.
// To get the stride and shape of the kernel, we need to read the tensor
// shape.
int image_height = inputTensorType.getShape()[2];
int image_width = inputTensorType.getShape()[3];
// Define the kernel shape and strides.
SmallVector<int64_t> kernelShapeVals = {image_height, image_width};
SmallVector<int64_t> stridesVals = {image_height, image_width};
SmallVector<int64_t> dilationsVals = {1, 1};
// Set the pads to 0.
SmallVector<int64_t> padsVals = {0, 0, 0, 0};
// Create the ArrayAttrs
auto kernelShape = mlir::ArrayAttr::get(
rewriter.getContext(), llvm::to_vector(llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto strides = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto dilations = mlir::ArrayAttr::get(
rewriter.getContext(), llvm::to_vector(llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto pads = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
// Create the resulting tensor type.
auto resultType = RankedTensorType::get(
/*shape=*/ {inputTensorType.getShape()[0], inputTensorType.getShape()[1], 1, 1},
/*elementType=*/inputTensorType.getElementType());
// Create the ONNXAveragePoolOp.
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
resultType,
inputTensor,
/*auto_pad=*/"NOTSET",
/*ceil_mode=*/0,
/*count_include_pad=*/1,
dilations,
/*kernel_shape=*/kernelShape,
/*pads=*/pads,
/*strides=*/strides);
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
rewriter.replaceOp(reduceMean, averagePool.getResult());
return success();
}
};
void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ReduceMeanConversionPattern>(ctx);
}
} // namespace onnx_mlir

View File

@@ -13,9 +13,11 @@ def onnxToArithConstantOp : Pat<
(Arith_ConstantOp $value)
>;
//===----------------------------------------------------------------------===//
// ONNXMatMulOp to ONNXGemmOp patterns
//===----------------------------------------------------------------------===//
def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">;
def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
@@ -24,24 +26,24 @@ def matMulAddToGemmPattern : Pat<
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
),
[(IsRank2Result $matmulres)]
>;
def matMulToGemmPattern : Pat<
(ONNXMatMulOp:$matmulres $A, $B),
(
ONNXGemmOp $A, $B,
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
),
[(IsRank2Result $matmulres)]
>;
//===----------------------------------------------------------------------===//
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
//===----------------------------------------------------------------------===//
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias.
@@ -55,9 +57,7 @@ def convAddToConvWithBiasPatternRight : Pat<
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
//===----------------------------------------------------------------------===//
// Operation to ignore (i.e. remove)
//===----------------------------------------------------------------------===//
def replaceWithOperationOfValue : NativeCodeCall<"$0">;

View File

@@ -1,31 +1,51 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "ONNXToSpatialPass.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace
void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
@@ -36,21 +56,23 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp);
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (annotateReplication(funcOp, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n";
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
ConversionTarget target(*ctx);
target.addLegalDialect<ONNXDialect, SpatialDialect, tensor::TensorDialect, arith::ArithDialect, tosa::TosaDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
});
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
@@ -59,23 +81,17 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXReshapeOp>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
if (useExperimentalConvImpl) {
populateExperimentalTilingConvOpPattern(patterns, ctx);
populateExperimentalPoolingTilingPattern(patterns, ctx);
populateGemmToConvConversionPattern(patterns, ctx);
}
else {
populateTilingConvOpPattern(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populateTilingGemmOpPattern(patterns, ctx);
}
populateConvOpPatterns(patterns, ctx);
populatePoolTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
@@ -85,8 +101,8 @@ void ONNXToSpatialPass::runOnOperation() {
// Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations())
if (isa<SpatWeightedCompute>(op))
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
@@ -96,29 +112,26 @@ void ONNXToSpatialPass::runOnOperation() {
}
}
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp);
annotateWeightsConstants(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial");
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<SpatWeightedCompute>(user); });
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
if (isAlwaysWeight)
constantOp->setAttr("weightAlways", UnitAttr::get(ctx));
markWeightAlways(constantOp);
});
}
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -1,34 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
using namespace mlir;
extern bool haveSameStaticShape(Value lhs, Value rhs);
namespace spatial {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
};
} // namespace spatial
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<spatial::ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -1,28 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
// Experimental patterns.
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64(*padsAttr, 0);
padWidthBegin = getI64(*padsAttr, 1);
padHeightEnd = getI64(*padsAttr, 2);
padWidthEnd = getI64(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC;
if (hasB)
gemmC = b;
else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
auto im2colComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
auto* im2colBlock = new Block();
im2colBlock->addArgument(x.getType(), loc);
im2colComputeOp.getBody().push_back(im2colBlock);
rewriter.setInsertionPointToStart(im2colBlock);
Value paddedInput = im2colBlock->getArgument(0);
// Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult();
}
// Build im2col [numPatches, patchSize]:
// For each batch/output position (n, oh, ow), extract the patch from x
SmallVector<Value> im2colRows;
im2colRows.reserve(numPatches);
for (int64_t n = 0; n < batchSize; n++) {
for (int64_t oh = 0; oh < outHeight; oh++) {
for (int64_t ow = 0; ow < outWidth; ow++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(oh * strideHeight),
rewriter.getIndexAttr(ow * strideWidth)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
im2colRows.push_back(row);
}
}
}
// Concatenate all rows: [numPatches, patchSize]
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
rewriter.setInsertionPointAfter(im2colComputeOp);
// Gemm: A @ B + C = im2col @ W^T + b
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2colComputeOp.getResult(0),
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
Value gemmOut = gemmOp.getY();
auto collectComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
auto* collectBlock = new Block();
collectBlock->addArgument(gemmOut.getType(), loc);
collectComputeOp.getBody().push_back(collectBlock);
rewriter.setInsertionPointToStart(collectBlock);
auto gemmOutArg = collectBlock->getArguments().front();
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOutArg,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
return success();
}
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,360 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static FailureOr<Value>
materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
if (factor == 1.0f)
return value;
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
if (!denseAttr)
return failure();
SmallVector<APFloat> scaledValues;
scaledValues.reserve(denseAttr.getNumElements());
APFloat scale(factor);
bool hadFailure = false;
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
APFloat scaledValue(originalValue);
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
hadFailure = true;
scaledValues.push_back(std::move(scaledValue));
}
if (hadFailure)
return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
const int64_t numOutRows = aType.getDimSize(0);
// Only decompose when there are multiple rows to split
if (numOutRows <= 1)
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
RankedTensorType cType = nullptr;
bool cHasNumOutRows = false;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows);
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c;
if (hasC) {
if (cHasNumOutRows) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
}
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
}
auto gemvOp = ONNXGemmOp::create(rewriter,
loc,
outRowType,
aSlice,
b,
cSlice,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
gemmOp.getTransAAttr(),
gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY());
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
auto* concatBlock = new Block();
for (auto gemvOp : gemvOps)
concatBlock->addArgument(gemvOp.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location gemmLoc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
Value out = gemmOp.getY();
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
float beta = gemmOpAdaptor.getBeta().convertToFloat();
bool transA = gemmOpAdaptor.getTransA();
bool transB = gemmOpAdaptor.getTransB();
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(out.getType());
RankedTensorType cType = nullptr;
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
gemmLoc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv
return failure();
if (transA) {
auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
}
if (transB) {
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType());
}
if (alpha != 1.0f) {
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
alpha = 1.0f;
}
if (hasC && beta != 1.0f) {
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
beta = 1.0f;
}
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices;
auto outLastHSliceSize = cLastHSliceSize;
const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());
DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
SmallVector<Value> cHSlices;
if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
if (hasC)
cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);
RankedTensorType outHSliceType =
RankedTensorType::get({1, static_cast<long>(crossbarSize)}, outType.getElementType());
RankedTensorType outLastHSliceType =
RankedTensorType::get({1, static_cast<long>(bLastHSliceSize)}, outType.getElementType());
SmallVector<Value> outHSlices;
outHSlices.reserve(outNumHSlices);
for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
RankedTensorType currOutHSliceType = outHSliceType;
if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
currOutHSliceType = outLastHSliceType;
SmallVector<Value> partialResults;
partialResults.reserve(coresPerVSlice);
for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
SmallVector<Value> weights;
weights.reserve(aHSlices[coreId].size());
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId])
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
auto computeArgs = computeBlock->getArguments();
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(computeArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp);
partialResults.push_back(computeOp.getResult(0));
}
if (hasC) {
Value cHSlice = cHSlices[outSliceId];
partialResults.push_back(cHSlice);
}
auto reduceComputeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
auto* reduceBlock = new Block();
for (auto partialResult : partialResults)
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
reduceComputeOp.getBody().push_back(reduceBlock);
rewriter.setInsertionPointToStart(reduceBlock);
auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp);
outHSlices.push_back(reduceComputeOp.getResult(0));
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
auto* concatBlock = new Block();
for (auto outHSlice : outHSlices)
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
return failure();
const int64_t batch = rhsType.getDimSize(0);
const int64_t k = rhsType.getDimSize(1);
const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
return failure();
Location loc = matmulOp.getLoc();
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhsTransposed =
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
SmallVector<Value> gemmRows;
gemmRows.reserve(batch * n);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
auto* concatBlock = new Block();
for (Value gemmRow : gemmRows)
concatBlock->addArgument(gemmRow.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
Value result = ONNXTransposeOp::create(
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result);
return success();
}
};
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
return reduced;
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
if (divisor == 1)
return reducedWindow;
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
}
template <typename PoolOp>
struct PoolToSpatialCompute;
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
using OpConversionPattern<PoolOp>::OpConversionPattern;
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location loc = poolOp.getLoc();
Value x = adaptor.getX();
auto xType = dyn_cast<RankedTensorType>(x.getType());
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
if (xType.getRank() != 4 || outType.getRank() != 4)
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
ArrayAttr kernelAttr = poolOp.getKernelShape();
if (!kernelAttr || kernelAttr.size() != 2)
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
const int64_t batchSize = xType.getDimSize(0);
const int64_t channels = xType.getDimSize(1);
const int64_t inputHeight = xType.getDimSize(2);
const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
int64_t padTop = 0;
int64_t padLeft = 0;
int64_t padBottom = 0;
int64_t padRight = 0;
if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2);
padRight = getI64(*padsAttr, 3);
}
else {
StringRef autoPad = poolOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
if (autoPad == "SAME_UPPER") {
padTop = totalPadH / 2;
padBottom = totalPadH - padTop;
padLeft = totalPadW / 2;
padRight = totalPadW - padLeft;
}
else {
padBottom = totalPadH / 2;
padTop = totalPadH - padBottom;
padRight = totalPadW / 2;
padLeft = totalPadW - padRight;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
}
}
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
auto* computeBlock = new Block();
computeBlock->addArgument(xType, loc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
Value input = computeBlock->getArgument(0);
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.replaceOp(poolOp, computeOp.getResult(0));
return success();
}
};
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
} // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,15 +1,15 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext* ctx)
struct Concat : public OpConversionPattern<ONNXConcatOp> {
Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
};
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx);
patterns.insert<Concat>(ctx);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,121 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(sourceIdx);
while (sourceProduct != resultProduct) {
if (sourceProduct > resultProduct)
return false;
sourceIdx++;
if (sourceIdx >= sourceShape.size())
return false;
group.push_back(sourceIdx);
sourceProduct *= sourceShape[sourceIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(resultIdx);
while (resultProduct != sourceProduct) {
if (resultProduct > sourceProduct)
return false;
resultIdx++;
if (resultIdx >= resultShape.size())
return false;
group.push_back(resultIdx);
resultProduct *= resultShape[resultIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
ONNXReshapeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
return failure();
if (sourceType == resultType) {
rewriter.replaceOp(reshapeOp, adaptor.getData());
return success();
}
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
return failure();
}
};
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Reshape>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,35 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : public OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {}
void initialize() { this->setHasBoundedRewriteRecursion(); }
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
if (op.getResult().use_empty()) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,119 +0,0 @@
#include <queue>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
/**
* @brief Structure that describes the replication of a convolution operation,
* along the image height axis.
*/
struct ConvReplication {
ONNXConvOp convOp; // Convolution operation
size_t input_w; // Width of the input image
size_t replicationFactor; // Replication factor on the image height axis
size_t coresNeededPerReplica; // Number of cores needed for each replica
friend bool operator<(const ConvReplication& a, const ConvReplication& b) {
return a.input_w / a.replicationFactor < b.input_w / b.replicationFactor;
}
ConvReplication(ONNXConvOp convOp, size_t input_w, size_t replicationFactor, size_t coresNeededPerReplica)
: convOp(convOp),
input_w(input_w),
replicationFactor(replicationFactor),
coresNeededPerReplica(coresNeededPerReplica) {}
};
LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter) {
if (coresCount == -1) {
// No need for annotation, implicitly set replication to 1
return success();
}
std::priority_queue<struct ConvReplication> convOpsReplicationQueue;
size_t minimumCores = 0;
for (auto& op : funcOp.getFunctionBody().begin()->getOperations()) {
if (auto convOp = dyn_cast<ONNXConvOp>(op)) {
// Convolution layer
Value X = convOp.getX(), W = convOp.getW();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
ShapedType wShape = mlir::cast<ShapedType>(W.getType());
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t krn_h = GET_KERNEL_HEIGHT(wShape);
size_t krn_w = GET_KERNEL_WIDTH(wShape);
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue());
auto neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;
auto neededCores = ceilIntegerDivide(neededXbars, crossbarCountInCore.getValue());
minimumCores += neededCores;
convOpsReplicationQueue.emplace(convOp, input_w, 1, neededCores);
}
else if (auto gemmOp = dyn_cast<ONNXGemmOp>(op)) {
// Fully connected layer
auto matrixTensorShape = cast<ShapedType>(gemmOp.getB().getType());
auto inputSize = matrixTensorShape.getDimSize(0);
auto outputSize = matrixTensorShape.getDimSize(1);
if (gemmOp.getTransB())
std::swap(inputSize, outputSize);
const size_t inputTilesCount = ceilIntegerDivide(inputSize, crossbarSize.getValue());
const size_t outputTilesCount = ceilIntegerDivide(outputSize, crossbarSize.getValue());
// Each output tile is computed by `coresPerOutputTile` cores. The
// entire input is given to each of these cores.
const size_t coresPerOutputTile = ceilIntegerDivide(inputTilesCount, crossbarCountInCore.getValue());
auto neededCores = coresPerOutputTile * outputTilesCount;
minimumCores += neededCores;
}
}
if (static_cast<size_t>(coresCount) < minimumCores) {
return funcOp->emitError("Not enough cores for this network: ")
<< minimumCores << " cores needed, but only " << static_cast<size_t>(coresCount) << " available.";
}
size_t availableCores = static_cast<size_t>(coresCount) - minimumCores;
// Consume all the elements in the queue
while (!convOpsReplicationQueue.empty()) {
auto convOpReplication = convOpsReplicationQueue.top();
convOpsReplicationQueue.pop();
// Check if we can replicate this convolution (e.g. we have enough cores)
if (availableCores > convOpReplication.coresNeededPerReplica * (convOpReplication.replicationFactor + 1)) {
// We can replicate this convolution: increment replicationFactor and put
// back in queue
availableCores -= convOpReplication.coresNeededPerReplica;
convOpReplication.replicationFactor++;
convOpsReplicationQueue.push(convOpReplication);
}
else {
// Cannot replicate this convolution anymore, annotate the operation
// with the replication factor
convOpReplication.convOp->setAttr(REPLICATION_ATTR_NAME,
rewriter.getI64IntegerAttr(convOpReplication.replicationFactor));
}
}
return success();
}
} // namespace onnx_mlir

View File

@@ -1,10 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
mlir::LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir

View File

@@ -1,346 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <unordered_map>
#include <utility>
#include "SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)
namespace onnx_mlir {
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun,
ConversionPatternRewriter& rewriter) {
assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result);
Value processedResult = processFun(result);
if (processedResult == result) {
// Sometimes we want processedResult to return the same value but do
// something else with it (e.g. in softmax we want to broadcast the value
// using a channel). In this case, we can just return the same value.
return resultNum;
}
yieldOp->insertOperands(yieldOp->getNumOperands(), processedResult);
return yieldOp.getNumOperands() - 1;
}
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
if (preprocess)
for (auto& computeOpAndResNum : computeOpsAndResNum)
GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
// It is possible that `computeOpsAndResNum` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation());
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
else
rewriter.setInsertionPointAfterValue(valueWithinCompute);
valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
}
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
}
// Now, reconstruct from the map the computeOpsAndResNum list
computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false;
for (auto& use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) {
// If the value is already used by the yieldOp, we can just use it
computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
yieldOpUseFound = true;
break;
}
}
if (yieldOpUseFound)
continue;
// If this result is not used within a yieldOp, then add it
auto resultNum = yieldOp->getNumOperands();
yieldOp->insertOperands(resultNum, valueWithinCompute);
computeOpsAndResNum.push_back({computeOp, resultNum});
}
Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the computeOpsAndResNum list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<ComputeAndResNum> computeOpsRef(computeOpsAndResNum);
while (computeOpsRef.size() > 1) {
SmallVector<ComputeAndResNum> nextComputeOps;
nextComputeOps.reserve(computeOpsRef.size() / 2);
for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
auto [firstCompute, firstResultNum] = computeOpsRef[i];
auto [secondCompute, secondResultNum] = computeOpsRef[i + 1];
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstCompute, secondCompute);
std::swap(firstResultNum, secondResultNum);
}
// We do not immediately alter the computeOps results/operands, instead we
// do it in a delayed manner, to avoid invalidating the references to the
// computeOps (which must be replaced by a cloned ComputeOp when changing
// the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp
Block& secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation
rewriter.setInsertionPoint(secondYield);
Value reduced = reduce(formerRes2, formerRes1);
// Unfortunately, it is not possible to update the result in place,
// because we may have already referenced it by <computeOp, resultNum>
// outside of this function, thus replacing it would invalidate the
// reference. Therefore, we need to append a new result to the yieldOp,
// and then at a later stage update the computeOp accordingly.
// Add `reduced` to the second yieldOp
auto secondYieldOperandNum = secondYield.getNumOperands();
secondYield->insertOperands(secondYieldOperandNum, reduced);
secondResultNum = secondYieldOperandNum;
// We should also add an entry for updating the results of the last
// operation (the one which never becomes a `firstCompute`): because it is
// not tracked by reducerChanges as `fromOp`
reducerChanges.push_back(
{firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum});
nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (computeOpsRef.size() % 2 == 1)
nextComputeOps.push_back(computeOpsRef.back());
// Replace the inputOps list with the new one.
computeOpsRef = llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
}
assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalComputeAndResNum = computeOpsRef[0];
// Force the update of the results of this computeOp, when finalizing
computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));
if (postprocess)
GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum));
}
void SpatialReducer::finalizeReduceUpdates() {
assert(reducesFinalized == false && "Cannot finalize two times.");
reducesFinalized = true;
// First, add the results to the computeOps
for (auto& reduceChange : reducerChanges)
updateResultsOfCompute(reduceChange.fromOp);
for (auto& c : computeOpNeedingResUpdate)
updateResultsOfCompute(c.getOperation());
for (auto& reducerChange : this->reducerChanges) {
auto fromOp = reducerChange.fromOp;
auto toOp = reducerChange.toOp;
auto fromOpResNum = reducerChange.fromOpResNum;
auto toOpOperandNum = reducerChange.toOpOperandNum;
auto fromComputeOp = opToReplacedCompute[fromOp];
assert(fromComputeOp && "fromOp should have been mapped before!");
// toComputeOp could be the existing pointer, or we have to remap it with
// `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp)
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
assert(toComputeOp->getNumOperands() == toOpOperandNum
&& "toOpOperandNum should be the last operand of toComputeOp, are the "
"operations in the right order?");
// Add the new operand to `toComputeOp`
auto fromResult = fromComputeOp.getResult(fromOpResNum);
toComputeOp->insertOperands(toOpOperandNum, fromResult);
incrementWeightedComputeInputsSegmentSize(toComputeOp, 1);
}
}
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
Operation* opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end())
opToCast = it->second;
else
opToCast = opAndResNum.first;
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second);
}
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again
return;
}
auto oldComputeOp = cast<spatial::SpatWeightedCompute>(computeOp);
auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map
opToReplacedCompute[oldComputeOp.getOperation()] = oldComputeOp;
return;
}
// Add the results by inspecting its YieldOp
auto newResultTypes = yieldOp.getOperandTypes();
// Create a new ComputeOp with the new result type, but same operands
rewriter.setInsertionPoint(oldComputeOp);
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
auto newComputeOpNum = newComputeOp->getNumOperands();
assert(oldComputeOpNum == newComputeOpNum);
// Since we replaced the old ComputeOp with a new one, we need to replace
// all its results' uses
for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
Value oldResult = oldComputeOp.getResult(i);
Value newResult = newComputeOp.getResult(i);
// Replace the uses, except the uses of the compute ops which got deleted
// previously
rewriter.replaceAllUsesExcept(oldResult, newResult, oldComputeOpsReplaced);
}
// Finally, erase the old computeOp and update the map
opToReplacedCompute[oldComputeOp.getOperation()] = newComputeOp;
oldComputeOpsReplaced.insert(oldComputeOp.getOperation());
rewriter.setInsertionPoint(oldComputeOp);
rewriter.eraseOp(oldComputeOp);
}
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
Location& loc,
Type outputType) {
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
// outputTiles are indexed like this: [channelTile][x][y]
auto tilesCount = outputTiles.size();
auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++)
for (size_t y = 0; y < height; y++)
remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]);
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
}
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter,
Value biasTile,
MapOperations mapOp) {
std::function<Value(const Value&)> postprocessing = nullptr;
if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) {
Value mapOperand = a;
if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
return createMapOperation(rewriter, mapOp, mapOperand);
};
}
return this->applyReducePattern(
computeOps,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
/* preprocess = */ nullptr,
postprocessing);
}
} // namespace onnx_mlir

View File

@@ -1,84 +0,0 @@
#pragma once
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange {
Operation* fromOp;
unsigned int fromOpResNum;
Operation* toOp;
unsigned int toOpOperandNum;
};
using OpAndResNum = std::pair<Operation*, ResNum>;
class SpatialReducer {
public:
SpatialReducer(ConversionPatternRewriter& rewriter)
: rewriter(rewriter) {}
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter,
Value biasTile,
MapOperations mapOp);
void finalizeReduceUpdates();
~SpatialReducer() {
if (!reducesFinalized)
finalizeReduceUpdates();
}
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Location& loc,
Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
private:
[[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun,
ConversionPatternRewriter& rewriter);
/**
* @brief Update the results of a ComputeOp.
*
* This function updates the results of a ComputeOp by taking a look at the
operands of its yieldOp.
* If the ComputeOp was replaced, it updates `opToReplacedCompute` with the
replaced ComputeOp.
*
* @param computeOp The ComputeOp to update the results of.
*/
void updateResultsOfCompute(Operation* computeOp);
ConversionPatternRewriter& rewriter;
bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized
SmallVector<SpatialReducerChange, 4> reducerChanges;
// List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
};
} // namespace onnx_mlir

View File

@@ -1,53 +0,0 @@
#include <cassert>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
namespace onnx_mlir {
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
: weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin();
SmallVector<Value>& values = it->second.begin()->second;
long inputTile = it->first;
long outputTile = it->second.begin()->first;
size_t n = std::min(amount, values.size());
crossbarsUsed += n;
SmallVector<Value> result;
result.assign(values.begin(), values.begin() + n);
if (n < values.size()) {
values.erase(values.begin(), values.begin() + n);
}
else {
it->second.erase(outputTile);
if (it->second.empty())
weights.erase(inputTile);
}
return {inputTile, outputTile, crossbarsUsed - n, result};
}
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
crossbarsUsed = 0;
SmallVector<TaggedWeights> result;
size_t remaining = n;
while (remaining > 0 && !weights.empty()) {
auto group = popGroup(remaining);
result.push_back(group);
remaining -= group.weights.size();
}
return result;
}
} // namespace onnx_mlir

View File

@@ -1,48 +0,0 @@
#pragma once
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <map>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
/**
* @brief A helper struct to store a group of weights.
*
*/
struct TaggedWeights {
long inputTile;
long outputTile;
size_t startingCrossbarIndex;
SmallVector<Value> weights;
};
/**
* @brief A helper class to subdivide weights into groups.
*
* Weights are stored as a map of maps of SmallVectors. The outer map is indexed
* by input tile, the inner map is indexed by output tile, and the SmallVector
* contains the weights for the filter. This class allows us to extract groups
* of weights from the map until we've extracted a certain number of elements,
* namely as many as we need to fill a compute unit.
*/
class WeightSubdivider {
private:
map<long, map<long, SmallVector<Value>>> weights;
size_t crossbarsUsed = 0;
TaggedWeights popGroup(size_t amount);
public:
WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights);
bool isEmpty() const;
SmallVector<TaggedWeights> popGroups(size_t n);
};
} // namespace onnx_mlir

View File

@@ -1,14 +1,17 @@
add_onnx_mlir_rewriter(SpatialToGraphviz)
add_onnx_mlir_library(OMSpatialToGraphviz
add_pim_library(OMSpatialToGraphviz
SpatialToGraphviz.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions
OMPIMCommon
OMPimCommon
OMONNXOps
SpatialOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
${PIM_GENERATED_INCLUDE_DIRS}
)

View File

@@ -10,8 +10,9 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
@@ -199,12 +200,12 @@ private:
void SpatialToGraphvizPass::runOnOperation() {
ModuleOp module = getOperation();
// Get the first OP, must be a FuncOp
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
if (!func) {
module->emitError("No FuncOp found in the begin of module");
auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp func = *entryFunc;
os << "digraph G {\n"
<< "\tnode [style=filled,color=white];\n";

View File

@@ -1,21 +0,0 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen)
add_onnx_mlir_library(OMSpatialToPIM
SpatialToPIMPass.hpp
SpatialToPIMPass.cpp
SpatialToPIMCommon.cpp
DEPENDS
SpatialToPIMIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPIMCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -1,28 +0,0 @@
#ifndef SPATIAL_TO_PIM
#define SPATIAL_TO_PIM
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/PIM/Pim.td"
#endif // OP_BASE
def spatToPimVMMOp : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVMOp : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVAddOp : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
#endif // SPATIAL_TO_PIM

View File

@@ -1,108 +0,0 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape);
template <class T>
size_t rangeLength(const iterator_range<T> range) {
return std::distance(range.begin(), range.end());
}
/**
* Retrieves the earliest operation that uses the given value within the value's
* block.
*
* @param value The value for which to find the earliest user operation.
* @return The earliest user operation that uses the given value within the
* current block.
*/
Operation* getEarliestUserWithinBlock(Value value);
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation);
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation);
static bool isMemoryContiguous(const ArrayRef<int64_t> srcShape,
const ArrayRef<int64_t> offsets,
const ArrayRef<int64_t> sizes,
const ArrayRef<int64_t> strides) {
// Check that all strides are 1
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
// Check offsets from right to left:
// The first offset_n at position n different from 0:
// - limits all sizes to the left to 1
// - limits size_n to dimension_n - offset_n
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
// Check sizes from right to left:
// The first size_n at position n different from shape_n limits all sizes to the left to 1
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) {
return rewriter.create<tensor::EmptyOp>(loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(Operation* op) { return isa<tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op); }
} // namespace onnx_mlir

View File

@@ -1,60 +0,0 @@
#pragma once
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
namespace pim {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace pim
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<pim::SpatialToPIMPass>(); }
} // namespace onnx_mlir

View File

@@ -1,12 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
namespace spatial {
// TODO: Add here eventual patterns
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,23 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
SpatialToPimIncGen
LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions
OMPimCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_GENERATED_INCLUDE_DIRS}
)

View File

@@ -5,9 +5,10 @@
#include <cassert>
#include <cstddef>
#include "SpatialToPIMCommon.hpp"
#include "Common.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
@@ -53,7 +54,7 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
Operation* getEarliestUserWithinBlock(Value value) {
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();
assert(!users.empty());
@@ -66,23 +67,24 @@ Operation* getEarliestUserWithinBlock(Value value) {
return earliestUser;
}
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
auto operandsAndUses =
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
return {operand, std::distance(operand.use_begin(), operand.use_end())};
});
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
Value result = operation->getResult(0);
mlir::Value result = operation->getResult(0);
auto resultType = result.getType();
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands =
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end())
@@ -90,8 +92,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera
auto resultShapedType = cast<ShapedType>(resultType);
rewriter.setInsertionPoint(operation);
return rewriter.create<tensor::EmptyOp>(
operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
return tensor::EmptyOp::create(
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,52 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());
}
/**
* Retrieves the earliest operation that uses the given value within the value's
* block.
*
* @param value The value for which to find the earliest user operation.
* @return The earliest user operation that uses the given value within the
* current block.
*/
mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
}
inline bool isAConcatOp(mlir::Operation* op) {
return llvm::isa<mlir::tensor::ConcatOp>(op) || llvm::isa<spatial::SpatImgConcatOp>(op);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,48 @@
#ifndef SPATIAL_TO_PIM
#define SPATIAL_TO_PIM
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def onnxToPimTransposeOp : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMMOp : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVMOp : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVAddOp : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMulOp : Pat<
(SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMaxOp : Pat<
(SpatVMaxOp:$srcOpRes $a, $b),
(PimVVMaxOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
#endif // SPATIAL_TO_PIM

View File

@@ -1,8 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
@@ -16,21 +19,122 @@
#include <string>
#include <utility>
#include "SpatialToPIMPass.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace pim {
namespace {
void SpatialToPIMPass::runOnOperation() {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel,
bool useBroadcastOp,
IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static size_t countComputeLeafUsers(Value value) {
size_t leafUserCount = 0;
auto walkUses = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (isa<spatial::SpatWeightedCompute>(owner)) {
leafUserCount++;
continue;
}
if (!isChannelUseChainOp(owner))
llvm_unreachable("Channel use chain contains unsupported op");
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
self(owner->getResult(0), self);
}
};
walkUses(value, walkUses);
return leafUserCount;
}
void SpatialToPimPass::runOnOperation() {
coreId = 1;
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
@@ -40,15 +144,21 @@ void SpatialToPIMPass::runOnOperation() {
return;
}
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (!funcOp)
llvm_unreachable("No FuncOp found in the begin of module");
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter);
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
signalPassFailure();
return;
}
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp);
@@ -74,10 +184,10 @@ void SpatialToPIMPass::runOnOperation() {
}
// Dump to file for debug
dumpModule(moduleOp, "pim");
dumpModule(moduleOp, "pim0");
}
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto& block = computeOp.getRegion().front();
@@ -87,35 +197,67 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
// If this result has no uses, then just skip it
if (result.use_empty())
continue;
auto yieldType = cast<TensorType>(yieldValue.getType());
/*
* Here we assume that ReturnOp are only reachable by the following patterns:
*
* 1)
* %0 = spat.compute([...])
* [%0 has one user, which is a ConcatOp]
* %1 = tensor.concat(%0)
* [%1 has one user, which is a ReturnOp]
* return %1
*
* 2)
* %0 = spat.compute([...])
* [%0 has one user, which is a ReturnOp]
* return %0
*
* If the IR is like 2), then we can store the tensor to the output global memory location
*/
auto resultUses = result.getUses();
auto numResultUses = rangeLength(resultUses);
if (numResultUses == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isChannelUseChainOp(resultUser)) {
SmallVector<Operation*> returnChain;
Value chainedValue = result;
Operation* chainUser = resultUser;
while (isChannelUseChainOp(chainUser)) {
returnChain.push_back(chainUser);
auto chainUses = chainUser->getResult(0).getUses();
if (rangeLength(chainUses) != 1)
break;
chainedValue = chainUser->getResult(0);
chainUser = chainUses.begin()->getOwner();
}
if (isa<func::ReturnOp>(chainUser)) {
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
rewriter.setInsertionPoint(yieldOp);
IRMapping mapping;
mapping.map(result, yieldValue);
Value storedValue = yieldValue;
for (Operation* op : returnChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
storedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
markOpToRemove(op);
}
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
Value outputTensor = outputTensors[resultIndexInReturn];
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
storedValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
continue;
}
}
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t offset = 0;
@@ -125,13 +267,14 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Store to global memory
Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>(loc,
outputTensor.getType(),
outputTensor,
yieldValue,
rewriter.getI32IntegerAttr(offset),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(numElements * elementSize));
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
yieldValue,
rewriter.getI32IntegerAttr(offset),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(numElements * elementSize));
continue;
}
@@ -156,14 +299,14 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Store to global memory
Value outputTensor = outputTensors[concatIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>(
loc,
outputTensor.getType(),
outputTensor,
yieldValue,
rewriter.getI32IntegerAttr(offset),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
yieldValue,
rewriter.getI32IntegerAttr(offset),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize));
continue;
}
}
@@ -175,23 +318,20 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// 1. Create a new ChannelOp
rewriter.setInsertionPoint(computeOp);
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Receive value through the channel
// If this result is used by more than one user, then use a "Broadcast"
// channel operation. However, there is a special case: we have a single
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
// will detect this case and update `useBroadcastOp` accordingly.
bool useBroadcastOp = (numResultUses > 1);
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
// 2. Receive value through the channel. Broadcast is needed whenever the
// value eventually reaches more than one compute consumer, even through a
// chain of view-like ops.
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
// 3. Send the value through the channel
rewriter.setInsertionPointAfterValue(yieldValue);
if (useBroadcastOp)
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue);
spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
else
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue);
spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
}
// Use `HaltOp` instead of `YieldOp`
@@ -200,17 +340,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Replace `spat.compute` with `pim.core`
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
auto& coreOpBlocks = coreOp.getBody().getBlocks();
block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock);
rewriter.create<PimHaltOp>(computeOp.getLoc());
PimHaltOp::create(rewriter, computeOp.getLoc());
}
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
@@ -247,24 +387,31 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
rewriter.setInsertionPointAfter(vmmOp);
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
}
});
}
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(returnValue);
}
else {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
}
}
}
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
@@ -273,9 +420,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType);
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>(
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
rewriter,
loc,
tensorType,
deviceTensor,
@@ -295,16 +443,19 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front());
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
auto toTensorOp =
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp);
funcOp.eraseArgument(i);
if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
}
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
@@ -318,6 +469,9 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
@@ -351,12 +505,15 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp);
return success();
}
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter) {
auto& computeBlock = computeOp.getRegion().front();
@@ -369,70 +526,74 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
Value receivedValue;
if (useBroadcastOp)
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel);
receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel);
receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
blockArg.replaceAllUsesWith(receivedValue);
Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) {
SmallVector<Operation*> clonedChain;
Value currentValue = consumerValue;
while (currentValue != channelSourceOp) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || !isChannelUseChainOp(definingOp))
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
clonedChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
IRMapping mapping;
mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
markOpToRemove(op);
}
replacementValue = cast<Value>(mapping.lookup(consumerValue));
}
assert(replacementValue.getType() == blockArg.getType()
&& "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue);
}
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
bool useBroadcastOp,
IRRewriter& rewriter) {
auto sourceOpUses = channelSourceOp.getUses();
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
if (useBroadcastOp == false) {
// if useBroadcastOp is false, then sourceOp must have only one user
assert(rangeLength(sourceOpUses) == 1);
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
auto reshapeOpUses = reshapeOp.getOutput().getUses();
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
if (reshapeOpUsesCount > 1)
useBroadcastOp = true;
}
}
for (auto& resultUse : sourceOpUses) {
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
if (computeUser) {
replaceBlockArgumentWithRecvOp(
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
continue;
}
if (!computeUser) {
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
if (!reshapeOp) {
resultUse.getOwner()->dump();
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
}
// The tensorType now becomes the one of the reshapeOp
channelTensorType = reshapeOp.getResult().getType();
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
if (!computeUser)
llvm_unreachable("ReshapeOp users must be ComputeOps");
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
replaceBlockArgumentWithRecvOp(
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue;
}
// Remove the reshapeOp, so that the sourceOp has no users
operationsToRemove.push_back(reshapeOp);
if (!isChannelUseChainOp(owner))
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
markOpToRemove(owner);
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
self(owner->getResult(0), self);
}
}
};
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
}
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
void SpatialToPimPass::markOpToRemove(Operation* op) {
if (!llvm::is_contained(operationsToRemove, op))
operationsToRemove.push_back(op);
}
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
for (auto it : llvm::enumerate(returnOp.getOperands())) {
Operation* returnOperand = it.value().getDefiningOp();
@@ -451,7 +612,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
}
}
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
@@ -461,15 +622,10 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
auto tensorType = receiveOp.getType();
Value receiveRes = receiveOp.getResult();
// Check if the receiveOp value has more than one user
auto receiveUses = receiveRes.getUses();
auto receiveUsesCount = rangeLength(receiveUses);
assert(receiveUsesCount > 0);
bool useBroadcastOp = receiveUsesCount > 1;
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
if (useBroadcastOp) {
// When receiving, we actually noticed that the value has more than one
@@ -480,6 +636,6 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
}
}
} // namespace pim
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
} // namespace onnx_mlir

View File

@@ -1,2 +1,2 @@
add_subdirectory(PIM)
add_subdirectory(Pim)
add_subdirectory(Spatial)

View File

@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -1,10 +1,13 @@
add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_onnx_mlir_library(PimOps
add_pim_library(PimOps
PimOps.hpp
PimOps.cpp
Transforms/PimBufferizableOpInterface.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
OMPimIncGen

View File

@@ -14,20 +14,13 @@ def PimDialect : Dialect {
let cppNamespace = "::onnx_mlir::pim";
}
// Base class for Pim dialect operations. This operation inherits from the
// base `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class PimOp<string mnemonic, list<Trait> traits = []> :
Op<PimDialect, mnemonic, traits>;
def PimTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
//===----------------------------------------------------------------------===//
// Communication operations
//===----------------------------------------------------------------------===//
// Communication
def PimSendOp: PimOp<"send", []> {
let arguments = (ins
@@ -63,9 +56,7 @@ def PimReceiveOp: PimOp<"receive", [DestinationStyleOpInterface]> {
}];
}
//===----------------------------------------------------------------------===//
// Core operations
//===----------------------------------------------------------------------===//
// Core
def PimCoreOp: PimOp<"core", [SingleBlock]> {
@@ -81,9 +72,7 @@ def PimCoreOp: PimOp<"core", [SingleBlock]> {
}];
}
//===----------------------------------------------------------------------===//
// Memory Operations
//===----------------------------------------------------------------------===//
// Memory
def PimConstantOp: PimOp<"constant", []> {
let description = [{
@@ -157,9 +146,62 @@ def PimMemCopyDevToHostOp: PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
}];
}
//===----------------------------------------------------------------------===//
// Core.Compute operations
//===----------------------------------------------------------------------===//
def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> {
let description = [{
Copy a memory region from and to the same memory
}];
let arguments = (ins
PimTensor: $dst,
PimTensor: $src,
I32Attr: $dstOffset,
I32Attr: $srcOffset,
I32Attr: $size
);
let results = (outs
PimTensor: $dstOut
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDstMutable();
}
}];
let assemblyFormat = [{
`(` $dst `,` $src `)` attr-dict `:` `(` type($dst) `,` type($src) `)` `->` type($dstOut)
}];
}
// Algebra
def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> {
let description = [{
Matrix transpose
}];
let arguments = (ins
PimTensor: $data,
I64ArrayAttr: $perms,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
let description = [{
@@ -181,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
@@ -205,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
}];
}
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
@@ -231,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
}];
}
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
let description = [{
Element-wise subtraction: c = a - b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
let description = [{
Element-wise multiplication: c = a * b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
let description = [{
Element-wise max: c = max(a, b)
}];
@@ -245,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Dot product: c = dot(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
@@ -286,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
);
}
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
Average all elements into a single one
}];
let arguments = (ins
PimTensor: $dividend,
PimTensor: $divisor,
PimTensor: $a,
PimTensor: $outBuf
);
@@ -317,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
);
}
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise exp: c = exp(a)
Element-wise tanh activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise sigmoid activation
}];
let arguments = (ins

View File

@@ -10,7 +10,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -20,7 +20,7 @@ namespace pim {
void PimDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"
>();
}
@@ -30,12 +30,13 @@ void PimDialect::initialize() {
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
}
POPULATE_DEPENDENCIES(PimVMaxOp)
POPULATE_DEPENDENCIES(PimVVDMulOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp)
POPULATE_DEPENDENCIES(PimVAvgOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp)
POPULATE_DEPENDENCIES(PimVTanhOp)
POPULATE_DEPENDENCIES(PimVSigmOp)
} // namespace pim
} // namespace onnx_mlir
@@ -45,5 +46,5 @@ POPULATE_DEPENDENCIES(PimVExpOp)
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc"

View File

@@ -12,7 +12,7 @@
#include <string>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc"
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc"

View File

@@ -0,0 +1,23 @@
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
add_pim_library(OMPimBufferization
PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp
Common.hpp
Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
PimBufferizationIncGen
LINK_LIBS PUBLIC
OMPimCommon
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_GENERATED_INCLUDE_DIRS}
)

View File

@@ -0,0 +1,9 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
return builder.getI32IntegerAttr(sizeInBytes);
}

View File

@@ -0,0 +1,11 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
namespace onnx_mlir {
namespace pim {
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -1,11 +1,11 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
#include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace bufferization;
@@ -13,6 +13,26 @@ using namespace bufferization;
namespace onnx_mlir {
namespace pim {
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op,
@@ -77,6 +97,32 @@ struct MemCopyDevToHostOpInterface
}
};
struct TransposeOpBufferizeInterface
: DstBufferizableOpInterfaceExternalModel<TransposeOpBufferizeInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto transposeOp = cast<PimTransposeOp>(op);
auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state);
if (failed(dataOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt);
return success();
}
};
struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBufferizeInterface, PimVMMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -139,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
}
};
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
template <typename OpTy>
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
@@ -154,32 +201,39 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto vaddOp = cast<PimVAddOp>(op);
auto binaryOp = cast<OpTy>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
if (failed(aOpt))
return failure();
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
if (failed(bOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
return success();
}
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
});
}

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
namespace pim {
void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,19 @@
#ifndef PIM_BUFFERIZATION
#define PIM_BUFFERIZATION
#ifndef OP_BASE
include "mlir/IR/PatternBase.td"
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def memrefCopyToPimMemCopyOp : Pat<
(CopyOp $src, $dst),
(PimMemCopyOp $dst, $src,
ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">,
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
(returnType $dst))
>;
#endif // PIM_BUFFERIZATION

View File

@@ -5,23 +5,43 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/raw_os_ostream.h"
#include "Common/PIMCommon.hpp"
#include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace pim {
namespace {
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; }
PimBufferizationPass() = default;
PimBufferizationPass(const PimBufferizationPass& pass) {}
void runOnOperation() final;
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
};
} // namespace
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
// Do One-Shot-Bufferization
// One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
@@ -30,7 +50,19 @@ void PimBufferizationPass::runOnOperation() {
signalPassFailure();
}
// Remove toTensor operations
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
// Remove toTensor operations: leave memrefs instead
moduleOp.walk([](bufferization::ToTensorOp toTensorOp) {
toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer());
toTensorOp.erase();
@@ -57,23 +89,22 @@ void PimBufferizationPass::runOnOperation() {
annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug
dumpModule(moduleOp, "pim_buf");
dumpModule(moduleOp, "pim1_buff");
}
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); })
&& !getGlobalOp->getUsers().empty();
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
assert("Weights must be constants" && globalMemrefOp.getConstant());
getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx));
globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx));
markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp);
}
});
}
} // namespace pim
std::unique_ptr<Pass> createBufferizePimPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir

View File

@@ -1,15 +1,22 @@
add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)
add_onnx_mlir_library(SpatialOps
add_pim_library(SpatialOps
SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
OMONNXIncGen
OMSpatialIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRBufferizationDialect
MLIRBufferizationTransforms
OMMlirDialects
OMONNXOps
OMPimCompilerOptions
PimOps
)

View File

@@ -24,8 +24,8 @@
#include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());

View File

@@ -17,8 +17,8 @@
#include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -34,11 +34,79 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
// Alloc an output memref
return rewriter.create<memref::AllocOp>(loc, memrefResultType);
return memref::AllocOp::create(rewriter, loc, memrefResultType);
}
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return pim::PimMemCopyOp::create(rewriter,
loc,
contiguousBuffer.getType(),
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
++usersIterator;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
++usersIterator;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
return failure();
}
Operation* otherUser = nullptr;
if (firstUser == op)
otherUser = secondUser;
else if (secondUser == op)
otherUser = firstUser;
else {
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
return failure();
}
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
return failure();
}
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
return failure();
}
return otherUser;
}
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
// This function requires the existence of ChannelNewOp and the other
@@ -49,7 +117,7 @@ llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsR
if (precomputedOtherCoreId)
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
if (failed(notOpUserOpt))
return failure();
Operation* notOpUser = *notOpUserOpt;
@@ -119,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref))
return failure();
memrefOperands.push_back(*memref);
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
}
// TODO: Support addiction with more than 2 operands
@@ -134,7 +202,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
memrefOperands.push_back(outputTensor);
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -169,11 +237,13 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
Value newValue =
rewriter
.create<ToTy>(
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
.getOutRes();
Value newValue = ToTy::create(rewriter,
op->getLoc(),
outputTensor.getType(),
cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -213,12 +283,12 @@ struct ChannelReceiveOpInterface
if (failed(srcCoreId))
return failure();
Value newValue = rewriter
.create<pim::PimReceiveOp>(op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOut();
replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -300,7 +370,8 @@ struct ChannelBroadcastReceiveOpInterface
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto outputType = cast<ShapedType>(outputTensor.getType());
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
@@ -323,13 +394,14 @@ struct ChannelBroadcastReceiveOpInterface
}
rewriter.setInsertionPoint(op);
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
outputTensor.getType(),
outputTensor,
bufferAllocation,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(outputSize));
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
bufferAllocation,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(outputSize));
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getDeviceDst());
@@ -356,7 +428,8 @@ struct ChannelBroadcastSendOpInterface
}
/*
* Turn the channel send to pim.send
* Turn the channel send into a device-to-host copy into the shared
* broadcast buffer that receive ops load from later.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
@@ -389,14 +462,25 @@ struct ChannelBroadcastSendOpInterface
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
}
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
rewriter.setInsertionPoint(op);
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
pim::PimMemCopyDevToHostOp::create(rewriter,
op->getLoc(),
bufferAllocation.getType(),
bufferAllocation,
srcMemRef,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(op);
return success();
}
};
struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
@@ -404,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
@@ -469,14 +551,15 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
outputTensor.getType(),
weightIndices,
xKernelPositions,
yKernelPositions,
*inputBuffer,
outputTensor,
accumBuffer);
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
weightIndices,
xKernelPositions,
yKernelPositions,
*inputBuffer,
outputTensor,
accumBuffer);
// Replace the operation with the bufferized value.
replaceOpWithBufferizedValues(rewriter, op, bufferized);
@@ -492,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
@@ -504,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
struct ONNXSigmoidInterface
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
});
}

View File

@@ -4,14 +4,12 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,17 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
Pim/ConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp
Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRLinalgDialect
OMPimCommon
)

View File

@@ -1,9 +1,8 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace mlir;

View File

@@ -1,6 +1,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Compiler/CompilerUtils.hpp"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -17,7 +18,10 @@ struct MessagePass : PassWrapper<MessagePass, OperationPass<ModuleOp>> {
: message(message) {}
MessagePass(const MessagePass& pass) {}
void runOnOperation() final { showCompilePhase(message); }
void runOnOperation() final {
llvm::outs() << message << "\n";
llvm::outs().flush();
}
private:
std::string message;

30
src/PIM/Pass/PIMPasses.h Normal file
View File

@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include <memory>
#include <string>
namespace onnx_mlir {
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
std::unique_ptr<mlir::Pass> createCountInstructionPass();
} // namespace onnx_mlir

View File

@@ -0,0 +1,121 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment) {
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,41 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
mlir::Location loc,
mlir::MemRefType globalType,
mlir::DenseElementsAttr denseAttr,
llvm::StringRef nameStem,
mlir::IntegerAttr alignment = {});
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
llvm::ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth);
} // namespace onnx_mlir

View File

@@ -0,0 +1,53 @@
#include "Patterns.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
populateConstantFoldingConstantPatterns(owningPatterns);
populateConstantFoldingSubviewPatterns(owningPatterns);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
signalPassFailure();
return;
}
dumpModule(getOperation(), "pim2_folded");
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir

View File

@@ -0,0 +1,487 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct ConstantSubviewCopy {
DenseElementsAttr source;
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
Operation* copyOp = nullptr;
};
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides(rank, 1);
SmallVector<int64_t> transposedStrides(rank, 1);
for (int64_t dim = rank - 2; dim >= 0; --dim) {
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
}
SmallVector<int64_t> originalIndices(rank);
SmallVector<int64_t> transposedIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
for (int64_t dim = 0; dim < rank; ++dim)
transposedIndices[dim] = originalIndices[perms[dim]];
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
if (!mapOp.getInputs().empty())
return failure();
auto yieldOp = dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
Attribute attr;
if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
return failure();
return attr;
}
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
return failure();
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
if (!initType || !initType.hasStaticShape())
return failure();
auto fillValue = getConstantMapYield(mapOp);
if (failed(fillValue))
return failure();
auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
auto moduleOp = mapOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
rewriter.setInsertionPoint(mapOp);
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
rewriter.eraseOp(mapOp);
return success();
}
};
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numElements = resultTensorType.getNumElements();
if (numElements < 0)
return failure();
Attribute fillValue;
SmallVector<ConstantSubviewCopy> copies;
llvm::SmallPtrSet<Operation*, 8> visitedAliases;
SmallVector<Value> pendingAliases;
pendingAliases.push_back(allocOp.getResult());
while (!pendingAliases.empty()) {
Value alias = pendingAliases.pop_back_val();
for (Operation* user : alias.getUsers()) {
if (!visitedAliases.insert(user).second)
continue;
if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
if (mapOp.getInit() != alias)
return failure();
auto maybeFillValue = getConstantMapYield(mapOp);
if (failed(maybeFillValue))
return failure();
if (fillValue && fillValue != *maybeFillValue)
return failure();
fillValue = *maybeFillValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
offsets.push_back(*staticOffset);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
strides.push_back(*staticStride);
}
for (Operation* subviewUser : subviewOp->getUsers()) {
if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
if (copyOp.getTarget() != subviewOp.getResult())
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
if (failed(denseAttr))
return failure();
copies.push_back({*denseAttr, offsets, strides, copyOp});
continue;
}
return failure();
}
continue;
}
if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
continue;
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
pendingAliases.push_back(castOp.getResult());
continue;
}
return failure();
}
}
if (!fillValue)
return failure();
SmallVector<Attribute> resultValues(numElements, fillValue);
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
});
for (const ConstantSubviewCopy& copy : copies) {
auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
|| sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
SmallVector<int64_t> sourceIndices =
delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
SmallVector<int64_t> resultIndices;
resultIndices.reserve(sourceIndices.size());
for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
resultIndices.push_back(offset + sourceIndex * stride);
int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
resultValues[resultLinearIndex] = value;
}
}
return DenseElementsAttr::get(resultTensorType, resultValues);
}
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal)
return failure();
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
if (!denseAttr)
return failure();
SmallVector<int64_t> perms;
perms.reserve(transposeOp.getPerms().size());
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
perms.push_back(attr.getInt());
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
if (failed(transposedAttr))
return failure();
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
if (!llvm::equal(transposedShape, resultType.getShape()))
return failure();
auto newGlobal = createFoldedGlobal(moduleOp,
transposeOp.getLoc(),
resultType,
*transposedAttr,
sourceGlobal.getName().str() + "__folded_transpose",
sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
bool isAlwaysWeight =
!transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
if (isAlwaysWeight) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
return success();
}
};
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
auto moduleOp = allocOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
if (failed(foldedAttr))
return failure();
auto allocType = cast<MemRefType>(allocOp.getType());
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
SmallVector<Operation*> opsToErase;
SmallVector<memref::CastOp> castsToReplace;
bool allLiveUsersAreCoreOps = true;
for (Operation* user : llvm::make_early_inc_range(allocOp->getUsers())) {
if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
opsToErase.push_back(user);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
castsToReplace.push_back(castOp);
continue;
}
if (!isa<pim::PimCoreOp>(user))
return failure();
}
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
})) {
allLiveUsersAreCoreOps = false;
}
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
})) {
return failure();
}
if (allLiveUsersAreCoreOps) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
for (memref::CastOp castOp : castsToReplace)
preservedUsers.insert(castOp);
rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
for (memref::CastOp castOp : castsToReplace) {
rewriter.setInsertionPoint(castOp);
Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
rewriter.replaceOp(castOp, replacementCast);
if (allLiveUsersAreCoreOps)
markWeightAlways(replacementCast.getDefiningOp());
}
for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
rewriter.eraseOp(subviewUser);
if (op->use_empty())
rewriter.eraseOp(op);
}
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success();
}
};
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
if (!allocOp)
return failure();
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
return failure();
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
sourceIndices.push_back(off + idx);
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
resultValues[i] = sourceValues[srcLinear];
}
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
continue;
if (isa<memref::DeallocOp>(user))
continue;
if (isa<pim::PimCoreOp>(user))
continue;
if (isa<memref::SubViewOp>(user)) {
allLiveUsersAreCores = false;
continue;
}
return failure();
}
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
if (allLiveUsersAreCores)
markWeightAlways(newGetGlobal);
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
rewriter.eraseOp(copyOp);
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success();
}
};
} // namespace
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantMemCpPattern>(patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,223 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src,
dstByteOffset,
srcByteOffset,
sliceBytes);
}
return success();
}
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDst(),
copyOp.getSrc(),
copyOp.getDstOffset(),
copyOp.getSrcOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDeviceDst(),
copyOp.getHostSrc(),
copyOp.getDeviceDstOffset(),
copyOp.getHostSrcOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
return success();
}
};
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
if (subviewOp.use_empty())
return failure();
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
return failure();
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
if (failed(denseAttr))
return failure();
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
if (failed(subviewInfo))
return failure();
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
auto resultMemRefType =
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
sourceIndices.push_back(off + idx);
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
return success();
}
};
} // namespace
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
patterns.add<RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern, FoldConstantCoreSubviewPattern>(
patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -1,6 +1,6 @@
#include "mlir/Pass/Pass.h"
#include "Common/PIMCommon.hpp"
#include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp"
using namespace mlir;

View File

@@ -0,0 +1,134 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
static int64_t getValueSizeInBytes(Value value) {
auto type = dyn_cast<ShapedType>(value.getType());
if (!type || !type.hasStaticShape())
return -1;
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
OpBuilder rewriter(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
if (isa<pim::PimHaltOp>(op))
continue;
for (OpOperand& operand : op.getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op.emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end()) {
operand.set(cachedValue->second);
continue;
}
int64_t totalBytes = getValueSizeInBytes(originalValue);
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
hasFailure = true;
continue;
}
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(&op);
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult());
}
}
}
}
if (hasFailure) {
signalPassFailure();
return;
}
dumpModule(moduleOp, "pim3_materialized");
}
};
} // namespace
std::unique_ptr<Pass> createMaterializeConstantsPass() {
return std::make_unique<MaterializeConstantsPass>();
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,194 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isAddressOnlyHostOp(Operation* op) {
return isa<arith::ConstantOp,
memref::AllocOp,
memref::GetGlobalOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp,
spatial::SpatChannelNewOp>(op);
}
static bool isCodegenAddressableValue(Value value) {
auto resolvedAddress = resolveContiguousAddress(value);
if (failed(resolvedAddress))
return false;
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override {
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
}
VerificationPass() {}
VerificationPass(const VerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
hasFailure = true;
continue;
}
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
if (failed(verifyReturnOp(returnOp)))
hasFailure = true;
continue;
}
if (!isAddressOnlyHostOp(&op)) {
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
"fold it to constants or lower it into pim.core");
hasFailure = true;
continue;
}
if (failed(verifyAddressOnlyHostOp(&op)))
hasFailure = true;
}
}
if (hasFailure)
signalPassFailure();
}
private:
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as memref.get_global before JSON codegen";
hasFailure = true;
continue;
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
hasFailure = true;
continue;
}
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must come from a constant memref.global with an initial value";
hasFailure = true;
}
}
return success(!hasFailure);
}
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
bool hasFailure = false;
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
if (!isCodegenAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
}
return success(!hasFailure);
}
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
bool hasFailure = false;
for (Operation& op : coreOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op))
continue;
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
auto resolvedAddress = resolveContiguousAddress(operand);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
}
}
return success(!hasFailure);
}
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlySource(op, subviewOp.getSource());
if (auto castOp = dyn_cast<memref::CastOp>(op))
return verifyAddressOnlySource(op, castOp.getSource());
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
return verifyAddressOnlySource(op, collapseOp.getSrc());
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
return verifyAddressOnlySource(op, expandOp.getSrc());
return success();
}
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
if (isCodegenAddressableValue(source))
return success();
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
return failure();
}
};
} // namespace
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
} // namespace onnx_mlir

View File

@@ -1,25 +0,0 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include <memory>
using namespace mlir;
namespace onnx_mlir {
std::unique_ptr<Pass> createONNXToSpatialPass();
std::unique_ptr<Pass> createSpatialToGraphvizPass();
std::unique_ptr<Pass> createSpatialToPIMPass();
std::unique_ptr<Pass> createBufferizePimPass();
std::unique_ptr<Pass> createEmitPimJsonPass();
std::unique_ptr<Pass> createMessagePass(std::string message);
std::unique_ptr<Pass> createCountInstructionPass();
} // namespace onnx_mlir

View File

@@ -4,6 +4,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -12,11 +13,11 @@
#include "llvm/Support/Debug.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/PimAccelerator.hpp"
#define DEBUG_TYPE "PimAccelerator"
@@ -40,39 +41,41 @@ PimAccelerator::PimAccelerator()
acceleratorTargets.push_back(this);
}
PimAccelerator::~PimAccelerator() { delete instance; }
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
void PimAccelerator::addPasses(OwningOpRef<ModuleOp>& module,
PassManager& pm,
void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
mlir::PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const {
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
addPassesPim(module, pm, emissionTarget, outputNameNoExt);
}
void PimAccelerator::registerDialects(DialectRegistry& registry) const {
void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
registry.insert<bufferization::BufferizationDialect>();
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<mlir::tosa::TosaDialect>();
registry.insert<mlir::bufferization::BufferizationDialect>();
registry.insert<pim::PimDialect>();
registry.insert<spatial::SpatialDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerBufferizableOpInterfaceExternalModels(registry);
pim::registerOpBufferizationInterfaces(registry);
}
void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
registerPass(createONNXToSpatialPass);
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPIMPass);
registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass);
registerPass(createConstantFoldingPass);
registerPass(createMaterializeConstantsPass);
registerPass(createVerificationPass);
registerPass(createEmitPimJsonPass);
}
@@ -81,26 +84,26 @@ void PimAccelerator::configurePasses() const {
// TODO: This does nothing for now.
}
MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const {
mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const {
// Do not convert tensor types to memref types.
return nullptr;
}
void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const {
void PimAccelerator::conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const {
target.addLegalDialect<pim::PimDialect>();
}
void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns,
TypeConverter& typeConverter,
MLIRContext* ctx) const {
void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
mlir::TypeConverter& typeConverter,
mlir::MLIRContext* ctx) const {
// TODO: Add patterns for conversion
}
void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {}
void PimAccelerator::conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const {}
void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns,
LLVMTypeConverter& typeConverter,
MLIRContext* ctx) const {
void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
mlir::LLVMTypeConverter& typeConverter,
mlir::MLIRContext* ctx) const {
// We should not need this, since we offload it all to PIM.
}

View File

@@ -18,8 +18,6 @@ public:
PimAccelerator(PimAccelerator&) = delete;
void operator=(const PimAccelerator&) = delete;
~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance.
static PimAccelerator* getInstance();

Some files were not shown because too many files have changed in this diff Show More