Compare commits

6 Commits

Author SHA1 Message Date
NiccoloN a1b29dffe0 fix CI (hopefully)
Validate Operations / validate-operations (push) Failing after 1h2m33s
2026-03-23 19:27:53 +01:00
NiccoloN 661170a9aa reimplement pool lowering
add pool validation
align PIM ops/codegen/parser with the ISA
move constant materialization to MLIR
rename the PIM verification/materialization passes
better folded-constant handling
2026-03-23 19:14:50 +01:00
NiccoloN 461bdd808d replace helper-op cleanup with canonicalization
clean up PIM pattern naming
remove unused ValueMap.hpp
2026-03-23 17:13:54 +01:00
NiccoloN 50c545539b clean up PIM CMake
update README.md
2026-03-23 16:39:14 +01:00
NiccoloN 11916a2595 refactor Pim constant folding pass
share contiguous address resolution in PimCommon
group patterns in subdir for each pass with pattern files
2026-03-23 15:36:58 +01:00
NiccoloN 670d6ce94f extend operation support for conv and gemm
add more tests in validation
2026-03-23 14:46:08 +01:00
89 changed files with 2666 additions and 1141 deletions
@@ -0,0 +1,54 @@
name: Prepare MLIR Cache
description: Restore or build the cached MLIR toolchain.
inputs:
llvm-commit:
description: LLVM commit to build.
required: true
mold-linker-flags:
description: Linker flags used to force mold.
required: true
runs:
using: composite
steps:
- name: Restore MLIR cache
id: restore-mlir-cache
uses: actions/cache/restore@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
- name: Clone LLVM
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
shell: bash
run: |
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
cd onnx-mlir/llvm-project
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
git checkout FETCH_HEAD
- name: Build MLIR
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
shell: bash
run: |
cmake -S onnx-mlir/llvm-project/llvm -B onnx-mlir/llvm-project/build -G Ninja \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_ENABLE_RUNTIMES="openmp" \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DENABLE_LIBOMPTARGET=OFF \
-DLLVM_ENABLE_LIBEDIT=OFF \
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
cmake --build onnx-mlir/llvm-project/build
- name: Save MLIR cache
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
@@ -0,0 +1,45 @@
name: Prepare Protobuf Cache
description: Restore or build the cached protobuf installation.
inputs:
protobuf-commit:
description: Protobuf tag or commit to build.
required: true
mold-linker-flags:
description: Linker flags used to force mold.
required: true
runs:
using: composite
steps:
- name: Restore protobuf cache
id: restore-protobuf-cache
uses: actions/cache/restore@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-commit }}
- name: Install protobuf
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
shell: bash
run: |
git clone --depth 1 --branch ${{ inputs.protobuf-commit }} https://github.com/protocolbuffers/protobuf
cmake -S protobuf -B protobuf/build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
cmake --build protobuf/build
sudo cmake --install protobuf/build
rm -rf protobuf
- name: Save protobuf cache
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-commit }}
@@ -0,0 +1,26 @@
name: Restore Raptor Build Cache
description: Restore the cached raptor build directory for incremental builds.
inputs:
key:
description: Exact cache key to restore.
required: true
restore-keys:
description: Prefixes used to restore the most recent compatible cache.
required: false
outputs:
cache-hit:
description: Whether the exact cache key was restored.
value: ${{ steps.restore-raptor-build-cache.outputs.cache-hit }}
runs:
using: composite
steps:
- name: Restore raptor build cache
id: restore-raptor-build-cache
uses: actions/cache/restore@v4
with:
path: build
key: ${{ inputs.key }}
restore-keys: ${{ inputs.restore-keys }}
@@ -0,0 +1,16 @@
name: Save Raptor Build Cache
description: Save the raptor build directory after a successful incremental build.
inputs:
key:
description: Cache key used to save the build directory.
required: true
runs:
using: composite
steps:
- name: Save raptor build cache
uses: actions/cache/save@v4
with:
path: build
key: ${{ inputs.key }}
-50
View File
@@ -1,50 +0,0 @@
name: Build MLIR Cache
on:
workflow_call:
inputs:
llvm-commit:
required: true
type: string
jobs:
build-mlir:
runs-on: ubuntu-latest
steps:
- name: Cache MLIR build
id: cache-mlir
uses: actions/cache@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
- name: Install build dependencies
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
sudo apt update
sudo apt install -y cmake ninja-build build-essential
- name: Clone LLVM
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
cd onnx-mlir/llvm-project
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
git checkout FETCH_HEAD
- name: Build MLIR
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
mkdir -p onnx-mlir/llvm-project/build
cd onnx-mlir/llvm-project/build
cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_ENABLE_RUNTIMES="openmp" \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DENABLE_LIBOMPTARGET=OFF \
-DLLVM_ENABLE_LIBEDIT=OFF
cmake --build .
+32 -46
View File
@@ -4,29 +4,14 @@ env:
LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286 LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286
PROTOBUF_COMMIT: v34.0 PROTOBUF_COMMIT: v34.0
CMAKE_VERSION: 4.3.0 CMAKE_VERSION: 4.3.0
MOLD_LINKER_FLAGS: -fuse-ld=mold
on: on:
push: push:
pull_request: pull_request:
jobs: jobs:
# Expose env vars as outputs so they can be passed to reusable workflows validate-operations:
config:
runs-on: ubuntu-latest
outputs:
llvm-commit: ${{ steps.set-vars.outputs.llvm-commit }}
steps:
- id: set-vars
run: echo "llvm-commit=$LLVM_COMMIT" >> "$GITHUB_OUTPUT"
build-mlir-cache:
needs: config
uses: ./.github/workflows/build_mlir_cache.yml
with:
llvm-commit: ${{ needs.config.outputs.llvm-commit }}
validate:
needs: build-mlir-cache
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@@ -39,7 +24,7 @@ jobs:
- name: Install system dependencies - name: Install system dependencies
run: | run: |
sudo apt update sudo apt update
sudo apt install -y ninja-build build-essential curl ca-certificates sudo apt install -y cmake ninja-build build-essential mold curl ca-certificates
- name: Install CMake - name: Install CMake
run: | run: |
@@ -58,27 +43,17 @@ jobs:
cmake --version cmake --version
which cmake which cmake
- name: Cache protobuf build - name: Prepare MLIR cache
id: cache-protobuf uses: ./.github/actions/prepare-mlir-cache
uses: actions/cache@v4
with: with:
path: | llvm-commit: ${{ env.LLVM_COMMIT }}
/usr/local/lib/libproto* mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ env.PROTOBUF_COMMIT }}
- name: Install protobuf - name: Prepare protobuf cache
if: steps.cache-protobuf.outputs.cache-hit != 'true' uses: ./.github/actions/prepare-protobuf-cache
run: | with:
git clone --depth 1 --branch ${{ env.PROTOBUF_COMMIT }} https://github.com/protocolbuffers/protobuf protobuf-commit: ${{ env.PROTOBUF_COMMIT }}
cd protobuf mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
mkdir build
cd build
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
ninja
sudo ninja install
cd ../..
rm -rf protobuf
- name: Register installed libraries - name: Register installed libraries
run: sudo ldconfig run: sudo ldconfig
@@ -94,23 +69,34 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: pip install numpy onnx colorama run: pip install numpy onnx colorama
- name: Restore MLIR cache - name: Restore raptor build cache
uses: actions/cache/restore@v4 id: restore-raptor-build-cache
uses: ./.github/actions/restore-raptor-build-cache
with: with:
path: onnx-mlir/llvm-project key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
key: mlir-${{ runner.os }}-${{ env.LLVM_COMMIT }} restore-keys: |
fail-on-cache-miss: true raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-
raptor-build-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-
- name: Build Raptor - name: Build Raptor
id: build-raptor
run: | run: |
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
mkdir -p build && cd build cmake -S . -B build -G Ninja \
cmake .. -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \ -DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR} -DMLIR_DIR=${MLIR_DIR} \
cmake --build . -DCMAKE_EXE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
-DCMAKE_SHARED_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
-DCMAKE_MODULE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}"
cmake --build build
- name: Save raptor build cache
if: steps.build-raptor.outcome == 'success' && steps.restore-raptor-build-cache.outputs.cache-hit != 'true'
uses: ./.github/actions/save-raptor-build-cache
with:
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
- name: Run validation - name: Run validation
run: | run: |
+9 -1
View File
@@ -29,7 +29,7 @@ Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
Moreover, if compiling with build type debug, it is also suggested to use Moreover, if compiling with build type debug, it is also suggested to use
mold as linker (you will need to install it if you don't have it already) mold as linker (you will need to install it if you don't have it already)
to reduce memory usage during linking. You can use it with: to reduce memory usage during linking. You can use it by setting the options:
``` ```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \ -DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" -DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
@@ -38,8 +38,16 @@ to reduce memory usage during linking. You can use it with:
### Raptor ### Raptor
Use the following commands to build Raptor. Use the following commands to build Raptor.
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor. Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
```
``` ```
git submodule update --init --recursive git submodule update --init --recursive
@@ -530,6 +530,7 @@ where
let r2_val = r2; let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported"); ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd); let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?; let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0]; let load1 = loads[0];
@@ -224,7 +224,21 @@ fn json_to_vvsub(
inst_data_builder: &mut InstructionDataBuilder, inst_data_builder: &mut InstructionDataBuilder,
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
todo!("Not present in the compiler"); let json = json.as_object().expect("Not an object");
assert_eq!("vvsub", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvsub, inst_data_builder.build());
Ok(()) Ok(())
} }
@@ -256,7 +270,21 @@ fn json_to_vvdmul(
inst_data_builder: &mut InstructionDataBuilder, inst_data_builder: &mut InstructionDataBuilder,
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
todo!("Not present in the compiler"); let json = json.as_object().expect("Not an object");
assert_eq!("vvdmul", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvdmul, inst_data_builder.build());
Ok(()) Ok(())
} }
@@ -306,7 +334,21 @@ fn json_to_vavg(
inst_data_builder: &mut InstructionDataBuilder, inst_data_builder: &mut InstructionDataBuilder,
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
todo!("Not present in the compiler"); let json = json.as_object().expect("Not an object");
assert_eq!("vavg", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vavg, inst_data_builder.build());
Ok(()) Ok(())
} }
@@ -358,7 +400,7 @@ fn json_to_vsigm(
json: &Value, json: &Value,
) -> Result<()> { ) -> Result<()> {
let json = json.as_object().expect("Not an object"); let json = json.as_object().expect("Not an object");
assert_eq!("vsigmoid", json_str!(json, "op")); assert_eq!("vsigm", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32; let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32; let rs1 = json_i64!(json, "rs1") as i32;
let len = json_i64!(json, "len") as i32; let len = json_i64!(json, "len") as i32;
+38 -17
View File
@@ -10,34 +10,54 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT}) set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT}) set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
add_subdirectory(Common) set(PIM_PUBLIC_INCLUDE_DIRS
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_onnx_mlir_library(OMPIMAccel
PimAccelerator.cpp
Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp
Pass/PimConstantFoldingPass.cpp
Pass/PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include ${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT} ${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT} ${PIM_SRC_ROOT}
${PIM_BIN_ROOT} ${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH} ${PIM_INCLUDE_PATH}
)
set(PIM_COMPILER_INCLUDE_DIRS
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
set(PIM_ACCEL_INCLUDE_DIRS
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
set(PIM_GENERATED_INCLUDE_DIRS
${PIM_INCLUDE_PATH}
)
function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN})
endfunction()
add_subdirectory(Dialect)
add_subdirectory(Common)
add_subdirectory(Pass)
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_pim_library(OMPIMAccel
PimAccelerator.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
onnx onnx
OMAccelerator OMAccelerator
OMPimCompilerUtils OMPimCompilerUtils
OMCompilerUtils OMPimPasses
OMONNXOps OMONNXOps
SpatialOps SpatialOps
PimOps PimOps
@@ -45,5 +65,6 @@ add_onnx_mlir_library(OMPIMAccel
OMSpatialToGraphviz OMSpatialToGraphviz
OMSpatialToPim OMSpatialToPim
OMPimCommon OMPimCommon
OMPimBufferization
MLIRTensorInferTypeOpInterfaceImpl MLIRTensorInferTypeOpInterfaceImpl
) )
+2 -8
View File
@@ -1,19 +1,13 @@
add_onnx_mlir_library(OMPimCommon add_pim_library(OMPimCommon
PimCommon.cpp PimCommon.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include ${PIM_PUBLIC_INCLUDE_DIRS}
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
onnx onnx
OMPimCompilerUtils
SpatialOps SpatialOps
PimOps PimOps
) )
+63
View File
@@ -1,3 +1,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <filesystem> #include <filesystem>
@@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
return true; return true;
} }
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
int64_t byteOffset = 0;
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress{value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = tiedOperand->get();
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|| llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress{value, byteOffset};
return failure();
}
}
} // namespace onnx_mlir } // namespace onnx_mlir
+7
View File
@@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir { namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
std::string getOutputDir(); std::string getOutputDir();
void createDirectory(const std::string& directory); void createDirectory(const std::string& directory);
@@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides); llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
} // namespace onnx_mlir } // namespace onnx_mlir
-44
View File
@@ -1,44 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
template <typename T>
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
map.erase(result);
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
map.erase(arg);
}
};
template <typename T>
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
NotErasableValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
}
};
+12 -19
View File
@@ -1,44 +1,37 @@
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS) add_pim_library(OMPimCompilerOptions
add_onnx_mlir_library(OMPimCompilerOptions
PimCompilerOptions.cpp PimCompilerOptions.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT} ${PIM_COMPILER_INCLUDE_DIRS}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
${OMLibs}
OMCompilerOptions OMCompilerOptions
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT} ${PIM_ACCEL_INCLUDE_DIRS}
${PIM_ONNX_MLIR_BIN_ROOT}
) )
add_onnx_mlir_library(OMPimCompilerUtils add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp PimCompilerUtils.cpp
PimCodeGen.cpp PimCodeGen.cpp
../Pass/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT} ${PIM_COMPILER_INCLUDE_DIRS}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
${OMLibs}
OMCompilerUtils
OMPimCompilerOptions OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
OMCompilerPasses OMCompilerPasses
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT} ${PIM_ACCEL_INCLUDE_DIRS}
${PIM_ONNX_MLIR_BIN_ROOT}
) )
+142 -62
View File
@@ -14,12 +14,11 @@
#include <cmath> #include <cmath>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
@@ -86,48 +85,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
} }
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const { size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
size_t offset = 0; auto resolvedAddress = resolveContiguousAddress(value);
while (true) { if (failed(resolvedAddress)) {
auto definingOp = value.getDefiningOp(); errs() << "Failed to resolve contiguous address for value: ";
if (!definingOp)
break;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
break;
value = tiedOperand->get();
}
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto source = subviewDefiningOp.getSource();
auto srcShape = source.getType().getShape();
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
size_t localOffset = subviewOffsets[i];
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
localOffset *= subviewSizes[j];
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
}
value = source;
}
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
}
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
}
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
}
else
break;
}
auto iter = memEntriesMap.find(value);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
value.print(errs()); value.print(errs());
errs() << "\n"; errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) { if (auto* definingOp = value.getDefiningOp()) {
@@ -135,10 +95,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
definingOp->print(errs()); definingOp->print(errs());
errs() << "\n"; errs() << "\n";
} }
llvm_unreachable("Failed to resolve contiguous address");
}
auto iter = memEntriesMap.find(resolvedAddress->base);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
resolvedAddress->base.print(errs());
errs() << "\n";
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Missing mem entry"); llvm_unreachable("Missing mem entry");
} }
return iter->second.address + offset; return iter->second.address + resolvedAddress->byteOffset;
} }
json::Object PimCodeGen::createEmptyOffset() { json::Object PimCodeGen::createEmptyOffset() {
@@ -264,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
} }
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const { static size_t getValueSizeInBytes(mlir::Value value) {
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf()); auto type = cast<ShapedType>(value.getType());
auto aAddr = memory.getValueAddress(vaddOp.getA()); return type.getNumElements() * type.getElementTypeBitWidth() / 8;
auto bAddr = memory.getValueAddress(vaddOp.getB()); }
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType()); void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8; auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvaddOp.getA());
auto bAddr = memory.getValueAddress(vvaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json; json::Object json;
json["op"] = "vvadd"; json["op"] = "vvadd";
@@ -279,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
json["rs1"] = 1; json["rs1"] = 1;
json["rs2"] = 2; json["rs2"] = 2;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
json["len"] = totalBytes; json["len"] = getValueSizeInBytes(vvaddOp.getA());
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const { void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf()); auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
auto aAddr = memory.getValueAddress(vmaxOp.getA()); auto aAddr = memory.getValueAddress(vvsubOp.getA());
auto bAddr = memory.getValueAddress(vmaxOp.getB()); auto bAddr = memory.getValueAddress(vvsubOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvsub";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvsubOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmulOp.getA());
auto bAddr = memory.getValueAddress(vvmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0); setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json; json::Object json;
@@ -295,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
json["rs1"] = 1; json["rs1"] = 1;
json["rs2"] = 2; json["rs2"] = 2;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvdmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
auto aAddr = memory.getValueAddress(vavgOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vavg";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vavgOp.getA());
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
@@ -308,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
json["rd"] = 0; json["rd"] = 0;
json["rs1"] = 1; json["rs1"] = 1;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vreluOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
auto aAddr = memory.getValueAddress(vtanhOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vtanh";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vtanhOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
auto aAddr = memory.getValueAddress(vsigmOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vsigm";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vsigmOp.getA());
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
@@ -365,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
vaddJson["rs1"] = 1; vaddJson["rs1"] = 1;
vaddJson["rs2"] = 2; vaddJson["rs2"] = 2;
vaddJson["offset"] = createEmptyOffset(); vaddJson["offset"] = createEmptyOffset();
vaddJson["len"] = 32 * outChannels;
emitInstruction(std::move(vaddJson)); emitInstruction(std::move(vaddJson));
} }
} }
@@ -506,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp); coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op)) else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp); coreCodeGen.codeGenVVAddOp(vvaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op)) else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVMaxOp(vmaxOp); coreCodeGen.codeGenVVSubOp(vvsubOp);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op)) else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp); coreCodeGen.codeGenVReluOp(vreluOp);
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) { else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (isa<pim::PimSumOp>(op)) {
// TODO: Implement somehow? // TODO: Implement somehow?
op.emitWarning("Operation is not yet supported in code generation"); op.emitWarning("Operation is not yet supported in code generation");
continue; continue;
+9 -4
View File
@@ -3,8 +3,7 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "Common/ValueMap.hpp" #include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -91,9 +90,15 @@ public:
template <typename MVMTy> template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix); void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenVAddOp(pim::PimVAddOp vaddOp) const; void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const; void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
}; };
+3 -3
View File
@@ -5,7 +5,6 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#define DEBUG_TYPE "PimCompilerUtils" #define DEBUG_TYPE "PimCompilerUtils"
@@ -48,8 +47,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass()); pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded")); pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimHostVerificationPass()); pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createMessagePass("Pim host verified")); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted")); pm.addPass(createMessagePass("Pim json code emitted"));
@@ -2,23 +2,27 @@ set(LLVM_TARGET_DEFINITIONS ONNXToSpatial.td)
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen) add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial add_pim_library(OMONNXToSpatial
Math/Gemm.cpp Patterns/Math/Gemm.cpp
Math/Conv.cpp Patterns/Math/Conv.cpp
NN/Pooling.cpp Patterns/Math/MatMul.cpp
NN/ReduceMean.cpp Patterns/NN/Pool.cpp
Tensor/ONNXConcatToTensorConcat.cpp Patterns/NN/ReduceMean.cpp
Tensor/RemoveUnusedHelperOps.cpp Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp
Utils/SpatialReducer.cpp Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp Utils/AnnotateReplication.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS DEPENDS
ONNXToSpatialIncGen ONNXToSpatialIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
OMPimCompilerOptions OMPimCompilerOptions
OMONNXOps OMONNXOps
@@ -26,5 +30,5 @@ add_onnx_mlir_library(OMONNXToSpatial
OMPimCommon OMPimCommon
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH} ${PIM_GENERATED_INCLUDE_DIRS}
) )
@@ -15,7 +15,7 @@
#include <optional> #include <optional>
#include <utility> #include <utility>
#include "ONNXToSpatialCommon.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,427 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
// Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1)
return valuesToReduce[0];
if (preprocess) {
for (auto& valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce);
}
}
// It is possible that `valuesToReduce` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& valToReduce : valuesToReduce) {
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
// } else {
// // Otherwise it is a block argument,
// computeOp->getBlock()->getParentOp();
// }
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
auto it = lastValueForCompute.find(computeOp);
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
else
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce;
}
lastValueForCompute[computeOp] = valToReduce;
}
// Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute)
valuesToReduce.push_back(entry.second);
Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the valuesToReduce list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
while (valuesToReduceRef.size() > 1) {
SmallVector<Value> nextValuesToReduce;
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
auto firstValue = valuesToReduceRef[i];
auto secondValue = valuesToReduceRef[i + 1];
auto firstCompute = firstValue.getParentBlock()->getParentOp();
auto secondCompute = secondValue.getParentBlock()->getParentOp();
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstValue, secondValue);
std::swap(firstCompute, secondCompute);
}
// 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute);
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue);
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
// 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue);
Value reduced = reduce(receivedValue, secondValue);
nextValuesToReduce.push_back(reduced);
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (valuesToReduceRef.size() % 2 == 1)
nextValuesToReduce.push_back(valuesToReduceRef.back());
// Replace the inputOps list with the new one.
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
}
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0];
if (postprocess) {
rewriter.setInsertionPointAfterValue(finalValue);
finalValue = postprocess(finalValue);
}
return finalValue;
}
template <typename PoolOp>
bool hasPostProcessPoolingWindow() {
return false;
}
template <>
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = getImageHeight(xShape);
size_t input_w = getImageWidth(xShape);
size_t output_h = getImageHeight(yShape);
size_t output_w = getImageWidth(yShape);
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
// 1: Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt =
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
// TODO: This requires a core for each input tile, which is not ideal. We
// can do better.
// If some input tiles come from the func.func operands, load
// them into a computeOp and yield them
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
tileLoc,
extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(),
extractSliceOp.getResult());
Block* tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock);
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0);
}
}
}
}
// 2: Tile the output tensor
// Output tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel
SmallVector<Value> valuesToPool;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
// Each output pixel tile is computed by pooling a window of input
// pixel tiles
valuesToPool.clear();
size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++;
continue;
}
Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool;
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber);
}
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
}
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData();
}
else {
return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
}
valuesToPool.push_back(valueToPool);
}
}
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???");
std::function<Value(const Value&)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
};
}
Value reducedWithinCompute = applyReducePatternNew(
valuesToPool,
rewriter,
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
nullptr,
postProcessFn);
// Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the
// users computeOps
auto computeOpOfReduced =
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel =
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue;
}
}
}
// TODO: outputTiles are not the results of the computeOps! We need to add
// them!
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
auto outputTile = outputTiles[outTile][outX][outY];
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
}
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
}
}
}
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage);
return success();
}
};
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir
@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
// ONNXMatMulOp to ONNXGemmOp patterns // ONNXMatMulOp to ONNXGemmOp patterns
def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">;
def matMulAddToGemmPattern : Pat< def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C, (ONNXGemmOp $A, $B, $C,
@@ -22,7 +26,8 @@ def matMulAddToGemmPattern : Pat<
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">) /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
) ),
[(IsRank2Result $matmulres)]
>; >;
def matMulToGemmPattern : Pat< def matMulToGemmPattern : Pat<
@@ -34,7 +39,8 @@ def matMulToGemmPattern : Pat<
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">), /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">) /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
) ),
[(IsRank2Result $matmulres)]
>; >;
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
@@ -1,8 +1,9 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
@@ -12,7 +13,7 @@
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -56,6 +57,7 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx); mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx); mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx); mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns)))) if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
@@ -74,7 +76,9 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>(); target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addIllegalOp<ONNXMatMulOp>(); target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
});
target.addIllegalOp<ONNXGemmOp>(); target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>(); target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>(); target.addIllegalOp<ONNXLRNOp>();
@@ -83,13 +87,15 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXConcatOp>(); target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXSoftmaxOp>(); target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>(); target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXReshapeOp>();
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx); patterns.add<removeLRNPattern>(ctx);
populateConvOpPatterns(patterns, ctx); populateConvOpPatterns(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx); populatePoolTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx);
@@ -113,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() {
} }
} }
// Remove trailing "helper ops" i.e. concat,img_concat,reshape. PassManager cleanupPM(ctx);
RewritePatternSet removeUnusedHelperOpsPatterns(ctx); cleanupPM.addPass(createCanonicalizerPass());
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx); if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
@@ -7,13 +7,15 @@ namespace onnx_mlir {
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
}); });
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})); Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none // Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp()); bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC; Value gemmC;
if (hasB) { if (hasB)
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType()); gemmC = b;
gemmC = tensor::ExpandShapeOp::create(rewriter,
loc,
biasType,
b,
SmallVector<ReassociationIndices> {
{0, 1}
});
}
else else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
@@ -11,7 +11,7 @@
#include <cassert> #include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -23,6 +23,38 @@ namespace {
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
static FailureOr<Value> materializeScaledConstantTensor(Value value,
float factor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (factor == 1.0f)
return value;
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
if (!denseAttr)
return failure();
SmallVector<APFloat> scaledValues;
scaledValues.reserve(denseAttr.getNumElements());
APFloat scale(factor);
bool hadFailure = false;
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
APFloat scaledValue(originalValue);
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
hadFailure = true;
scaledValues.push_back(std::move(scaledValue));
}
if (hadFailure)
return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> { struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
if (numOutRows <= 1) if (numOutRows <= 1)
return failure(); return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
RankedTensorType cType = nullptr; RankedTensorType cType = nullptr;
bool cHasNumOutRows = false; bool cHasNumOutRows = false;
if (hasC) { if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType()); cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows; cHasNumOutRows = cType.getDimSize(0) == numOutRows;
} }
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
aSlice, aSlice,
b, b,
cSlice, cSlice,
gemmOp.getAlphaAttr(), rewriter.getF32FloatAttr(1.0f),
gemmOp.getBetaAttr(), rewriter.getF32FloatAttr(1.0f),
gemmOp.getTransAAttr(), gemmOp.getTransAAttr(),
gemmOp.getTransBAttr()); gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY()); gemvOps.push_back(gemvOp.getY());
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) { if (hasC) {
cType = cast<RankedTensorType>(c.getType()); cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
} }
@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType());
} }
if (alpha != 1.0f) { if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType()); auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); if (failed(scaledB))
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue); return failure();
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor); b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
alpha = 1.0f;
} }
if (hasC && beta != 1.0f) { if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType()); auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); if (failed(scaledC))
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue); return failure();
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor); c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
beta = 1.0f;
} }
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
return failure();
const int64_t batch = rhsType.getDimSize(0);
const int64_t k = rhsType.getDimSize(1);
const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
return failure();
Location loc = matmulOp.getLoc();
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhsTransposed =
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
SmallVector<Value> gemmRows;
gemmRows.reserve(batch * n);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
auto* concatBlock = new Block();
for (Value gemmRow : gemmRows)
concatBlock->addArgument(gemmRow.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
Value result = ONNXTransposeOp::create(
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result);
return success();
}
};
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
return reduced;
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
if (divisor == 1)
return reducedWindow;
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
}
template <typename PoolOp>
struct PoolToSpatialCompute;
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
using OpConversionPattern<PoolOp>::OpConversionPattern;
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location loc = poolOp.getLoc();
Value x = adaptor.getX();
auto xType = dyn_cast<RankedTensorType>(x.getType());
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
if (xType.getRank() != 4 || outType.getRank() != 4)
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
ArrayAttr kernelAttr = poolOp.getKernelShape();
if (!kernelAttr || kernelAttr.size() != 2)
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
const int64_t batchSize = xType.getDimSize(0);
const int64_t channels = xType.getDimSize(1);
const int64_t inputHeight = xType.getDimSize(2);
const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
int64_t padTop = 0;
int64_t padLeft = 0;
int64_t padBottom = 0;
int64_t padRight = 0;
if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2);
padRight = getI64(*padsAttr, 3);
}
else {
StringRef autoPad = poolOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
if (autoPad == "SAME_UPPER") {
padTop = totalPadH / 2;
padBottom = totalPadH - padTop;
padLeft = totalPadW / 2;
padRight = totalPadW - padLeft;
}
else {
padBottom = totalPadH / 2;
padTop = totalPadH - padBottom;
padRight = totalPadW / 2;
padLeft = totalPadW - padRight;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
}
}
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
auto* computeBlock = new Block();
computeBlock->addArgument(xType, loc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
Value input = computeBlock->getArgument(0);
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.replaceOp(poolOp, computeOp.getResult(0));
return success();
}
};
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
} // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
} // namespace onnx_mlir
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -1,15 +1,15 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> { struct Concat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext* ctx) Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {} : OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
}; };
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx); patterns.insert<Concat>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,121 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(sourceIdx);
while (sourceProduct != resultProduct) {
if (sourceProduct > resultProduct)
return false;
sourceIdx++;
if (sourceIdx >= sourceShape.size())
return false;
group.push_back(sourceIdx);
sourceProduct *= sourceShape[sourceIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(resultIdx);
while (resultProduct != sourceProduct) {
if (resultProduct > sourceProduct)
return false;
resultIdx++;
if (resultIdx >= resultShape.size())
return false;
group.push_back(resultIdx);
resultProduct *= resultShape[resultIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
ONNXReshapeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
return failure();
if (sourceType == resultType) {
rewriter.replaceOp(reshapeOp, adaptor.getData());
return success();
}
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
return failure();
}
};
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Reshape>(ctx);
}
} // namespace onnx_mlir
@@ -1,35 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {}
void initialize() { this->setHasBoundedRewriteRecursion(); }
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
if (op.getResult().use_empty()) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
}
} // namespace onnx_mlir
@@ -1,7 +1,7 @@
#include <queue> #include <queue>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -7,7 +7,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -1,14 +1,17 @@
add_onnx_mlir_rewriter(SpatialToGraphviz) add_onnx_mlir_rewriter(SpatialToGraphviz)
add_onnx_mlir_library(OMSpatialToGraphviz add_pim_library(OMSpatialToGraphviz
SpatialToGraphviz.cpp SpatialToGraphviz.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
OMPimCommon OMPimCommon
OMONNXOps OMONNXOps
SpatialOps SpatialOps
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH} ${PIM_GENERATED_INCLUDE_DIRS}
) )
@@ -2,19 +2,22 @@ set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen) add_public_tablegen_target(SpatialToPimIncGen)
add_onnx_mlir_library(OMSpatialToPim add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp SpatialToPimPass.cpp
SpatialToPimCommon.cpp Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS DEPENDS
SpatialToPimIncGen SpatialToPimIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
OMPimCommon OMPimCommon
SpatialOps SpatialOps
PimOps PimOps
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH} ${PIM_GENERATED_INCLUDE_DIRS}
) )
@@ -5,7 +5,7 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include "SpatialToPimCommon.hpp" #include "Common.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimVAddOp : Pat< def spatToPimVVAddOp : Pat<
(SpatVAddOp:$srcOpRes $a, $b), (SpatVAddOp:$srcOpRes $a, $b),
(PimVAddOp $a, $b, (PimVVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMulOp : Pat<
(SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMaxOp : Pat<
(SpatVMaxOp:$srcOpRes $a, $b),
(PimVVMaxOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -18,9 +19,9 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -56,10 +57,8 @@ private:
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value channelSourceOp, void
spatial::SpatChannelNewOp& channel, addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
bool useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp, Value channelSourceOp,
@@ -79,8 +78,31 @@ private:
} // namespace } // namespace
static bool isChannelUseChainOp(Operation* op) { static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>( return isa<tensor::ExtractSliceOp,
op); tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
} }
static size_t countComputeLeafUsers(Value value) { static size_t countComputeLeafUsers(Value value) {
@@ -175,35 +197,67 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) { for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
// If this result has no uses, then just skip it
if (result.use_empty()) if (result.use_empty())
continue; continue;
auto yieldType = cast<TensorType>(yieldValue.getType()); auto yieldType = cast<TensorType>(yieldValue.getType());
/*
* Here we assume that ReturnOp are only reachable by the following patterns:
*
* 1)
* %0 = spat.compute([...])
* [%0 has one user, which is a ConcatOp]
* %1 = tensor.concat(%0)
* [%1 has one user, which is a ReturnOp]
* return %1
*
* 2)
* %0 = spat.compute([...])
* [%0 has one user, which is a ReturnOp]
* return %0
*
* If the IR is like 2), then we can store the tensor to the output global memory location
*/
auto resultUses = result.getUses(); auto resultUses = result.getUses();
auto numResultUses = rangeLength(resultUses); auto numResultUses = rangeLength(resultUses);
if (numResultUses == 1) { if (numResultUses == 1) {
OpOperand& resultUse = *resultUses.begin(); OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner(); Operation* resultUser = resultUse.getOwner();
if (isChannelUseChainOp(resultUser)) {
SmallVector<Operation*> returnChain;
Value chainedValue = result;
Operation* chainUser = resultUser;
while (isChannelUseChainOp(chainUser)) {
returnChain.push_back(chainUser);
auto chainUses = chainUser->getResult(0).getUses();
if (rangeLength(chainUses) != 1)
break;
chainedValue = chainUser->getResult(0);
chainUser = chainUses.begin()->getOwner();
}
if (isa<func::ReturnOp>(chainUser)) {
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
rewriter.setInsertionPoint(yieldOp);
IRMapping mapping;
mapping.map(result, yieldValue);
Value storedValue = yieldValue;
for (Operation* op : returnChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
storedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
markOpToRemove(op);
}
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
Value outputTensor = outputTensors[resultIndexInReturn];
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
storedValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
continue;
}
}
if (isa<func::ReturnOp>(resultUser)) { if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t offset = 0; size_t offset = 0;
@@ -475,7 +529,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
receivedValue = receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else else
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
Value replacementValue = receivedValue; Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) { if (consumerValue != channelSourceOp) {
@@ -493,6 +548,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
IRMapping mapping; IRMapping mapping;
mapping.map(channelSourceOp, receivedValue); mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) { for (Operation* op : llvm::reverse(clonedChain)) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping); Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult); mapping.map(originalResult, newResult);
@@ -502,7 +558,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
replacementValue = cast<Value>(mapping.lookup(consumerValue)); replacementValue = cast<Value>(mapping.lookup(consumerValue));
} }
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type"); assert(replacementValue.getType() == blockArg.getType()
&& "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue); blockArg.replaceAllUsesWith(replacementValue);
} }
@@ -1,12 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
namespace spatial {
// TODO: Add here eventual patterns
}
} // namespace onnx_mlir
+3 -1
View File
@@ -3,10 +3,12 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization) add_subdirectory(Transforms/Bufferization)
add_onnx_mlir_library(PimOps add_pim_library(PimOps
PimOps.hpp PimOps.hpp
PimOps.cpp PimOps.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS DEPENDS
OMPimIncGen OMPimIncGen
+100 -8
View File
@@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> { def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
let description = [{ let description = [{
Element-wise addition: c = a + b Element-wise addition: c = a + b
}]; }];
@@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
let description = [{
Element-wise subtraction: c = a - b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
let description = [{
Element-wise multiplication: c = a * b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
let description = [{ let description = [{
Element-wise max: c = max(a, b) Element-wise max: c = max(a, b)
}]; }];
@@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
let results = (outs let results = (outs
PimTensor: $outRes PimTensor: $outRes
); );
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Dot product: c = dot(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
} }
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
@@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
); );
} }
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{ let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience) Average all elements into a single one
}]; }];
let arguments = (ins let arguments = (ins
PimTensor: $dividend, PimTensor: $a,
PimTensor: $divisor,
PimTensor: $outBuf PimTensor: $outBuf
); );
@@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
); );
} }
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> { def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{ let description = [{
Element-wise exp: c = exp(a) Element-wise tanh activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise sigmoid activation
}]; }];
let arguments = (ins let arguments = (ins
+4 -3
View File
@@ -30,12 +30,13 @@ void PimDialect::initialize() {
registerDependenciesFn(this->getOutBuf(), this->getResult()); \ registerDependenciesFn(this->getOutBuf(), this->getResult()); \
} }
POPULATE_DEPENDENCIES(PimVMaxOp) POPULATE_DEPENDENCIES(PimVVDMulOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp) POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp) POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp) POPULATE_DEPENDENCIES(PimVAvgOp)
POPULATE_DEPENDENCIES(PimVReluOp) POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp) POPULATE_DEPENDENCIES(PimVTanhOp)
POPULATE_DEPENDENCIES(PimVSigmOp)
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -2,13 +2,15 @@ set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen) add_public_tablegen_target(PimBufferizationIncGen)
add_onnx_mlir_library(OMPimBufferization add_pim_library(OMPimBufferization
PimBufferizationPass.cpp PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp OpBufferizationInterfaces.cpp
Common.hpp Common.hpp
Common.cpp Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS DEPENDS
PimBufferizationIncGen PimBufferizationIncGen
@@ -17,5 +19,5 @@ add_onnx_mlir_library(OMPimBufferization
PimOps PimOps
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH} ${PIM_GENERATED_INCLUDE_DIRS}
) )
@@ -4,6 +4,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp" #include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -12,6 +13,26 @@ using namespace bufferization;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
struct MemCopyHostToDevOpInterface struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
} }
}; };
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> { template <typename OpTy>
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
@@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
RewriterBase& rewriter, RewriterBase& rewriter,
const BufferizationOptions& options, const BufferizationOptions& options,
BufferizationState& state) const { BufferizationState& state) const {
auto vaddOp = cast<PimVAddOp>(op); auto binaryOp = cast<OpTy>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state); auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
if (failed(aOpt)) if (failed(aOpt))
return failure(); return failure();
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state); auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
if (failed(bOpt)) if (failed(bOpt))
return failure(); return failure();
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state); auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
if (failed(outBufOpt)) if (failed(outBufOpt))
return failure(); return failure();
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt); Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
return success(); return success();
} }
}; };
@@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx); PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx); PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx); PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
}); });
} }
+8 -2
View File
@@ -1,16 +1,22 @@
add_onnx_mlir_dialect(Spatial spat) add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td) add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps
add_onnx_mlir_library(SpatialOps
SpatialOps.cpp SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp Transforms/SpatialBufferizableOpInterface.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS DEPENDS
OMONNXIncGen OMONNXIncGen
OMSpatialIncGen OMSpatialIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRIR MLIRIR
MLIRBufferizationDialect
MLIRBufferizationTransforms
OMMlirDialects OMMlirDialects
OMONNXOps
OMPimCompilerOptions
PimOps
) )
+1 -1
View File
@@ -24,7 +24,7 @@
#include <cstdint> #include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -37,8 +37,76 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
return memref::AllocOp::create(rewriter, loc, memrefResultType); return memref::AllocOp::create(rewriter, loc, memrefResultType);
} }
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return pim::PimMemCopyOp::create(rewriter,
loc,
contiguousBuffer.getType(),
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
++usersIterator;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
++usersIterator;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
return failure();
}
Operation* otherUser = nullptr;
if (firstUser == op)
otherUser = secondUser;
else if (secondUser == op)
otherUser = firstUser;
else {
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
return failure();
}
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
return failure();
}
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
return failure();
}
return otherUser;
}
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
// This function requires the existence of ChannelNewOp and the other // This function requires the existence of ChannelNewOp and the other
@@ -49,7 +117,7 @@ llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsR
if (precomputedOtherCoreId) if (precomputedOtherCoreId)
return cast<IntegerAttr>(precomputedOtherCoreId).getInt(); return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter); auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
if (failed(notOpUserOpt)) if (failed(notOpUserOpt))
return failure(); return failure();
Operation* notOpUser = *notOpUserOpt; Operation* notOpUser = *notOpUserOpt;
@@ -119,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
auto memref = getBuffer(rewriter, operand, options, state); auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref)) if (failed(memref))
return failure(); return failure();
memrefOperands.push_back(*memref); memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
} }
// TODO: Support addiction with more than 2 operands // TODO: Support addiction with more than 2 operands
@@ -412,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
}; };
struct VAddOpInterfaceFromTemplate struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {}; : VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {}; struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
@@ -420,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {}; struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {}; struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation. // Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> { struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
@@ -509,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx); SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx); SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx); SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx); SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx); SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx); SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
@@ -521,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {}; struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {}; struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
struct ONNXSigmoidInterface
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) { void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx); ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx); ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
}); });
} }
+16
View File
@@ -0,0 +1,16 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
PimConstantFolding/Common.cpp
PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp
PimMaterializeConstantsPass.cpp
PimVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRLinalgDialect
OMPimCommon
)
-1
View File
@@ -3,7 +3,6 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace mlir; using namespace mlir;
+6 -2
View File
@@ -1,6 +1,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "src/Compiler/CompilerUtils.hpp" #include "llvm/Support/raw_ostream.h"
using namespace mlir; using namespace mlir;
@@ -17,7 +18,10 @@ struct MessagePass : PassWrapper<MessagePass, OperationPass<ModuleOp>> {
: message(message) {} : message(message) {}
MessagePass(const MessagePass& pass) {} MessagePass(const MessagePass& pass) {}
void runOnOperation() final { showCompilePhase(message); } void runOnOperation() final {
llvm::outs() << message << "\n";
llvm::outs().flush();
}
private: private:
std::string message; std::string message;
+121
View File
@@ -0,0 +1,121 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment) {
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
} // namespace onnx_mlir
@@ -0,0 +1,41 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
mlir::Location loc,
mlir::MemRefType globalType,
mlir::DenseElementsAttr denseAttr,
llvm::StringRef nameStem,
mlir::IntegerAttr alignment = {});
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
llvm::ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth);
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -1,19 +1,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "../Common.hpp"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "../Patterns.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -21,55 +13,14 @@
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static Value stripMemRefCasts(Value value) { struct ConstantSubviewCopy {
while (auto castOp = value.getDefiningOp<memref::CastOp>()) DenseElementsAttr source;
value = castOp.getSource(); SmallVector<int64_t> offsets;
return value; SmallVector<int64_t> strides;
} Operation* copyOp = nullptr;
};
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment = {}) {
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) { static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType()); auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
@@ -126,13 +77,6 @@ static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr den
return DenseElementsAttr::get(transposedType, transposedValues); return DenseElementsAttr::get(transposedType, transposedValues);
} }
struct ConstantSubviewCopy {
DenseElementsAttr source;
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
Operation* copyOp = nullptr;
};
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) { static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
if (!mapOp.getInputs().empty()) if (!mapOp.getInputs().empty())
return failure(); return failure();
@@ -176,151 +120,13 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(coreOp); rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
if (elementByteWidth == 0)
return failure();
size_t totalBytes = initType.getNumElements() * elementByteWidth;
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
pim::PimMemCopyHostToDevOp::create(rewriter, rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
mapOp.getLoc(),
initType,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
rewriter.eraseOp(mapOp); rewriter.eraseOp(mapOp);
return success(); return success();
} }
}; };
struct StaticSubviewInfo {
Value source;
SmallVector<int64_t> sourceShape;
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
};
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
static int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (copyOp.getSize() != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = copyOp.getSrcOffset()
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = copyOp.getDstOffset()
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : copyOp.getDst(),
splitSrc ? srcSubview->source : copyOp.getSrc(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) { static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
auto allocType = dyn_cast<MemRefType>(allocOp.getType()); auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape()) if (!allocType || !allocType.hasStaticShape())
@@ -473,17 +279,15 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
if (!llvm::equal(transposedShape, resultType.getShape())) if (!llvm::equal(transposedShape, resultType.getShape()))
return failure(); return failure();
MemRefType globalType = resultType;
auto newGlobal = createFoldedGlobal(moduleOp, auto newGlobal = createFoldedGlobal(moduleOp,
transposeOp.getLoc(), transposeOp.getLoc(),
globalType, resultType,
*transposedAttr, *transposedAttr,
sourceGlobal.getName().str() + "__folded_transpose", sourceGlobal.getName().str() + "__folded_transpose",
sourceGlobal.getAlignmentAttr()); sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp); rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName()); auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
bool isAlwaysWeight = bool isAlwaysWeight =
!transposeOp->getUsers().empty() !transposeOp->getUsers().empty()
@@ -578,41 +382,106 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
} }
}; };
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> { struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass) using OpRewritePattern::OpRewritePattern;
StringRef getArgument() const override { return "pim-constant-folding-pass"; } LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } if (copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
LogicalResult initialize(MLIRContext* context) override { auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
RewritePatternSet owningPatterns(context); if (!allocOp)
for (auto* dialect : context->getLoadedDialects()) return failure();
dialect->getCanonicalizationPatterns(owningPatterns); auto allocType = dyn_cast<MemRefType>(allocOp.getType());
for (RegisteredOperationName op : context->getRegisteredOperations()) if (!allocType || !allocType.hasStaticShape())
op.getCanonicalizationPatterns(owningPatterns, context); return failure();
owningPatterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>( if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
context); return failure();
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
return failure();
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
sourceIndices.push_back(off + idx);
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
resultValues[i] = sourceValues[srcLinear];
}
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
continue;
if (isa<memref::DeallocOp>(user))
continue;
if (isa<pim::PimCoreOp>(user))
continue;
if (isa<memref::SubViewOp>(user)) {
allLiveUsersAreCores = false;
continue;
}
return failure();
}
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
if (allLiveUsersAreCores)
markWeightAlways(newGetGlobal);
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
rewriter.eraseOp(copyOp);
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success(); return success();
} }
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
signalPassFailure();
return;
}
dumpModule(getOperation(), "pim2_folded");
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
}; };
} // namespace } // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); } void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantMemCpPattern>(patterns.getContext());
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,223 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src,
dstByteOffset,
srcByteOffset,
sliceBytes);
}
return success();
}
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDst(),
copyOp.getSrc(),
copyOp.getDstOffset(),
copyOp.getSrcOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDeviceDst(),
copyOp.getHostSrc(),
copyOp.getDeviceDstOffset(),
copyOp.getHostSrcOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
return success();
}
};
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
if (subviewOp.use_empty())
return failure();
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
return failure();
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
if (failed(denseAttr))
return failure();
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
if (failed(subviewInfo))
return failure();
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
auto resultMemRefType =
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
sourceIndices.push_back(off + idx);
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
return success();
}
};
} // namespace
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
patterns.add<RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern, FoldConstantCoreSubviewPattern>(
patterns.getContext());
}
} // namespace onnx_mlir
@@ -0,0 +1,53 @@
#include "Patterns.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
populateConstantFoldingConstantPatterns(owningPatterns);
populateConstantFoldingSubviewPatterns(owningPatterns);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
signalPassFailure();
return;
}
dumpModule(getOperation(), "pim2_folded");
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
} // namespace onnx_mlir
@@ -0,0 +1,135 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
static int64_t getValueSizeInBytes(Value value) {
auto type = dyn_cast<ShapedType>(value.getType());
if (!type || !type.hasStaticShape())
return -1;
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct PimMaterializeConstantsPass
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
OpBuilder rewriter(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
if (isa<pim::PimHaltOp>(op))
continue;
for (OpOperand& operand : op.getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op.emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end()) {
operand.set(cachedValue->second);
continue;
}
int64_t totalBytes = getValueSizeInBytes(originalValue);
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
hasFailure = true;
continue;
}
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(&op);
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult());
}
}
}
}
if (hasFailure) {
signalPassFailure();
return;
}
dumpModule(moduleOp, "pim3_materialized");
}
};
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
return std::make_unique<PimMaterializeConstantsPass>();
}
} // namespace onnx_mlir
+3 -1
View File
@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass(); std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostVerificationPass(); std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass(); std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
@@ -26,49 +27,32 @@ static bool isAddressOnlyHostOp(Operation* op) {
spatial::SpatChannelNewOp>(op); spatial::SpatChannelNewOp>(op);
} }
static bool isHostAddressableValue(Value value) { static bool isCodegenAddressableValue(Value value) {
while (true) { auto resolvedAddress = resolveContiguousAddress(value);
if (auto blockArg = dyn_cast<BlockArgument>(value)) if (failed(resolvedAddress))
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return false; return false;
return isa<BlockArgument>(resolvedAddress->base)
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp)) || isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
return true;
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return false;
}
} }
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass) if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
StringRef getArgument() const override { return "verify-pim-host-pass"; } struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override { StringRef getDescription() const override {
return "Verify that no runtime host-side code remains in bufferized PIM IR"; return "Verify that bufferized PIM IR contains only explicit host/device transfers";
} }
PimHostVerificationPass() {} PimVerificationPass() {}
PimHostVerificationPass(const PimHostVerificationPass& pass) {} PimVerificationPass(const PimVerificationPass& pass) {}
void runOnOperation() override { void runOnOperation() override {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -80,7 +64,7 @@ struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationP
for (Operation& op : funcOp.getBody().front().getOperations()) { for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) { if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreOp))) if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
hasFailure = true; hasFailure = true;
continue; continue;
} }
@@ -139,14 +123,49 @@ private:
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) { static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
bool hasFailure = false; bool hasFailure = false;
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) { for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
if (!isHostAddressableValue(operand)) { if (!isCodegenAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage"; returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
hasFailure = true; hasFailure = true;
} }
} }
return success(!hasFailure); return success(!hasFailure);
} }
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
bool hasFailure = false;
for (Operation& op : coreOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op))
continue;
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
auto resolvedAddress = resolveContiguousAddress(operand);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
}
}
return success(!hasFailure);
}
static LogicalResult verifyAddressOnlyHostOp(Operation* op) { static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op)) if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlySource(op, subviewOp.getSource()); return verifyAddressOnlySource(op, subviewOp.getSource());
@@ -160,16 +179,16 @@ private:
} }
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) { static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
if (isHostAddressableValue(source)) if (isCodegenAddressableValue(source))
return success(); return success();
op->emitOpError("depends on a value that still requires host-side execution"); op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
return failure(); return failure();
} }
}; };
} // namespace } // namespace
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); } std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
+2 -1
View File
@@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass); registerPass(createBufferizePimPass);
registerPass(createPimConstantFoldingPass); registerPass(createPimConstantFoldingPass);
registerPass(createPimHostVerificationPass); registerPass(createPimMaterializeConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }
+59
View File
@@ -0,0 +1,59 @@
# Validation Operations
ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline.
Generated tests can be regenerated with:
```
python3 validation/operations/gen_tests.py
```
## Conv
| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes |
|------|-----------|-------|--------|--------|--------|---------|------|-------|
| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted |
| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias |
| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input |
| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel |
| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution |
| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels |
| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing |
| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved |
| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads |
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
## Pool
| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes |
|------|-----------|-------|--------|--------|--------|---------|-------|
| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling |
| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool |
| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path |
| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling |
| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` |
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
## Gemm
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|------|-----------|-----------|------------|--------|--------|-------|------|------|-------|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias |
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
## Gemv
| Test | Directory | Input | W (weight) | Output | Bias | Notes |
|------|-----------|-------|------------|--------|------|-------|
| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul |
| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant |
| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape |
| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern |
| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+364
View File
@@ -0,0 +1,364 @@
#!/usr/bin/env python3
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
from pathlib import Path
OPERATIONS_DIR = Path(__file__).parent
def save_model(model, directory, filename):
"""Save an ONNX model, creating the directory if needed."""
d = OPERATIONS_DIR / directory
d.mkdir(parents=True, exist_ok=True)
path = d / filename
onnx.checker.check_model(model)
onnx.save(model, str(path))
print(f" {path.relative_to(OPERATIONS_DIR)}")
# ---------------------------------------------------------------------------
# GEMM tests
# ---------------------------------------------------------------------------
def gemm_non_square():
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
B, K, N = 4, 128, 64
W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/non_square", "gemm_non_square.onnx")
def gemm_with_bias():
"""GEMM with bias: Y = A @ W + C."""
B, K, N = 4, 128, 128
rng = np.random.default_rng(43)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"])
graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/with_bias", "gemm_with_bias.onnx")
def gemm_transB():
"""GEMM with transB=1: Y = A @ W^T."""
B, K, N = 4, 128, 64
rng = np.random.default_rng(44)
# W stored as [N, K], transposed during computation
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1)
graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/transB", "gemm_transB.onnx")
def gemm_alpha_beta():
"""GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C."""
B, K, N = 4, 64, 64
rng = np.random.default_rng(45)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25)
graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx")
def gemm_small():
"""Small GEMM: [2, 8] @ [8, 4]."""
B, K, N = 2, 8, 4
rng = np.random.default_rng(46)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/small", "gemm_small.onnx")
def gemm_large():
"""Larger GEMM: [8, 256] @ [256, 128]."""
B, K, N = 8, 256, 128
rng = np.random.default_rng(47)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/large", "gemm_large.onnx")
def gemm_transB_with_bias():
"""GEMM with transB and bias: Y = A @ W^T + C."""
B, K, N = 4, 128, 64
rng = np.random.default_rng(48)
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1)
graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
# ---------------------------------------------------------------------------
# Conv tests
# ---------------------------------------------------------------------------
def conv_3x3_kernel():
"""Conv with 3x3 kernel, no padding."""
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
W = numpy_helper.from_array(
np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx")
def conv_stride2():
"""Conv with 3x3 kernel and stride 2."""
# Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
W = numpy_helper.from_array(
np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/stride_2", "conv_stride_2.onnx")
def conv_multi_channel():
"""Conv with multiple input and output channels."""
# Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3])
W = numpy_helper.from_array(
np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/multi_channel", "conv_multi_channel.onnx")
def conv_1x1():
"""1x1 pointwise convolution (channel mixing)."""
# Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
W = numpy_helper.from_array(
np.random.default_rng(53).uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_1x1", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx")
def conv_same_padding_3x3():
"""Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions."""
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5])
W = numpy_helper.from_array(
np.random.default_rng(54).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "conv_same_3x3", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx")
def conv_explicit_padding():
"""Conv 3x3 with explicit asymmetric padding."""
# Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4])
W = numpy_helper.from_array(
np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1])
graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx")
def conv_with_bias_3x3():
"""Conv 3x3 with bias."""
# Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3]
rng = np.random.default_rng(56)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_with_bias_3x3", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx")
def conv_batch_2():
"""Batched conv (batch=2) with SAME_UPPER padding and bias."""
# Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3]
rng = np.random.default_rng(57)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (1,)).astype(np.float32), name="B")
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "conv_batch_2", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/batch_2", "conv_batch_2.onnx")
def conv_large_spatial():
"""Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6])
W = numpy_helper.from_array(
np.random.default_rng(58).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_large_spatial", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
# ---------------------------------------------------------------------------
# Pooling tests
# ---------------------------------------------------------------------------
def maxpool_basic():
"""MaxPool 2x2 with stride 1."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "maxpool_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_basic", "maxpool_basic.onnx")
def maxpool_stride2_multichannel():
"""MaxPool 2x2 with stride 2 on multiple channels."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx")
def maxpool_same_upper():
"""MaxPool 3x3 with SAME_UPPER padding."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx")
def avgpool_basic():
"""AveragePool 2x2 with stride 1."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3])
node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "avgpool_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_basic", "avgpool_basic.onnx")
def avgpool_explicit_padding():
"""AveragePool 3x3 with explicit padding, excluding pad from the divisor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node("AveragePool", ["X"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0)
graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx")
def avgpool_include_pad():
"""AveragePool 3x3 with explicit padding, including pad in the divisor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node("AveragePool", ["X"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1)
graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx")
def maxpool_after_conv():
"""Conv followed by MaxPool to validate pooling on lowered conv results."""
rng = np.random.default_rng(59)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2])
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("Generating GEMM tests:")
gemm_non_square()
gemm_with_bias()
gemm_transB()
gemm_alpha_beta()
gemm_small()
gemm_large()
gemm_transB_with_bias()
print("\nGenerating Conv tests:")
conv_3x3_kernel()
conv_stride2()
conv_multi_channel()
conv_1x1()
conv_same_padding_3x3()
conv_explicit_padding()
conv_with_bias_3x3()
conv_batch_2()
conv_large_spatial()
print("\nGenerating Pooling tests:")
maxpool_basic()
maxpool_stride2_multichannel()
maxpool_same_upper()
avgpool_basic()
avgpool_explicit_padding()
avgpool_include_pad()
maxpool_after_conv()
print("\nDone.")