Compare commits

6 Commits

Author SHA1 Message Date
NiccoloN a1b29dffe0 fix CI (hopefully)
Validate Operations / validate-operations (push) Failing after 1h2m33s
2026-03-23 19:27:53 +01:00
NiccoloN 661170a9aa reimplement pool lowering
add pool validation
align PIM ops/codegen/parser with the ISA
move constant materialization to MLIR
rename the PIM verification/materialization passes
better folded-constant handling
2026-03-23 19:14:50 +01:00
NiccoloN 461bdd808d replace helper-op cleanup with canonicalization
clean up PIM pattern naming
remove unused ValueMap.hpp
2026-03-23 17:13:54 +01:00
NiccoloN 50c545539b clean up PIM CMake
update README.md
2026-03-23 16:39:14 +01:00
NiccoloN 11916a2595 refactor Pim constant folding pass
share contiguous address resolution in PimCommon
group patterns in subdir for each pass with pattern files
2026-03-23 15:36:58 +01:00
NiccoloN 670d6ce94f extend operation support for conv and gemm
add more tests in validation
2026-03-23 14:46:08 +01:00
89 changed files with 2666 additions and 1141 deletions
@@ -0,0 +1,54 @@
name: Prepare MLIR Cache
description: Restore or build the cached MLIR toolchain.
inputs:
llvm-commit:
description: LLVM commit to build.
required: true
mold-linker-flags:
description: Linker flags used to force mold.
required: true
runs:
using: composite
steps:
- name: Restore MLIR cache
id: restore-mlir-cache
uses: actions/cache/restore@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
- name: Clone LLVM
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
shell: bash
run: |
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
cd onnx-mlir/llvm-project
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
git checkout FETCH_HEAD
- name: Build MLIR
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
shell: bash
run: |
cmake -S onnx-mlir/llvm-project/llvm -B onnx-mlir/llvm-project/build -G Ninja \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_ENABLE_RUNTIMES="openmp" \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DENABLE_LIBOMPTARGET=OFF \
-DLLVM_ENABLE_LIBEDIT=OFF \
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
cmake --build onnx-mlir/llvm-project/build
- name: Save MLIR cache
if: steps.restore-mlir-cache.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
@@ -0,0 +1,45 @@
name: Prepare Protobuf Cache
description: Restore or build the cached protobuf installation.
inputs:
protobuf-commit:
description: Protobuf tag or commit to build.
required: true
mold-linker-flags:
description: Linker flags used to force mold.
required: true
runs:
using: composite
steps:
- name: Restore protobuf cache
id: restore-protobuf-cache
uses: actions/cache/restore@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-commit }}
- name: Install protobuf
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
shell: bash
run: |
git clone --depth 1 --branch ${{ inputs.protobuf-commit }} https://github.com/protocolbuffers/protobuf
cmake -S protobuf -B protobuf/build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_EXE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_SHARED_LINKER_FLAGS="${{ inputs.mold-linker-flags }}" \
-DCMAKE_MODULE_LINKER_FLAGS="${{ inputs.mold-linker-flags }}"
cmake --build protobuf/build
sudo cmake --install protobuf/build
rm -rf protobuf
- name: Save protobuf cache
if: steps.restore-protobuf-cache.outputs.cache-hit != 'true'
uses: actions/cache/save@v4
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ inputs.protobuf-commit }}
@@ -0,0 +1,26 @@
name: Restore Raptor Build Cache
description: Restore the cached raptor build directory for incremental builds.
inputs:
key:
description: Exact cache key to restore.
required: true
restore-keys:
description: Prefixes used to restore the most recent compatible cache.
required: false
outputs:
cache-hit:
description: Whether the exact cache key was restored.
value: ${{ steps.restore-raptor-build-cache.outputs.cache-hit }}
runs:
using: composite
steps:
- name: Restore raptor build cache
id: restore-raptor-build-cache
uses: actions/cache/restore@v4
with:
path: build
key: ${{ inputs.key }}
restore-keys: ${{ inputs.restore-keys }}
@@ -0,0 +1,16 @@
name: Save Raptor Build Cache
description: Save the raptor build directory after a successful incremental build.
inputs:
key:
description: Cache key used to save the build directory.
required: true
runs:
using: composite
steps:
- name: Save raptor build cache
uses: actions/cache/save@v4
with:
path: build
key: ${{ inputs.key }}
-50
View File
@@ -1,50 +0,0 @@
name: Build MLIR Cache
on:
workflow_call:
inputs:
llvm-commit:
required: true
type: string
jobs:
build-mlir:
runs-on: ubuntu-latest
steps:
- name: Cache MLIR build
id: cache-mlir
uses: actions/cache@v4
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ inputs.llvm-commit }}
- name: Install build dependencies
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
sudo apt update
sudo apt install -y cmake ninja-build build-essential
- name: Clone LLVM
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
git clone --filter=blob:none --no-checkout https://github.com/llvm/llvm-project.git onnx-mlir/llvm-project
cd onnx-mlir/llvm-project
git fetch --depth 1 origin ${{ inputs.llvm-commit }}
git checkout FETCH_HEAD
- name: Build MLIR
if: steps.cache-mlir.outputs.cache-hit != 'true'
run: |
mkdir -p onnx-mlir/llvm-project/build
cd onnx-mlir/llvm-project/build
cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_ENABLE_RUNTIMES="openmp" \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON \
-DENABLE_LIBOMPTARGET=OFF \
-DLLVM_ENABLE_LIBEDIT=OFF
cmake --build .
+32 -46
View File
@@ -4,29 +4,14 @@ env:
LLVM_COMMIT: 0c2701fe7fa002e1befc5f86c268a7964f96d286
PROTOBUF_COMMIT: v34.0
CMAKE_VERSION: 4.3.0
MOLD_LINKER_FLAGS: -fuse-ld=mold
on:
push:
pull_request:
jobs:
# Expose env vars as outputs so they can be passed to reusable workflows
config:
runs-on: ubuntu-latest
outputs:
llvm-commit: ${{ steps.set-vars.outputs.llvm-commit }}
steps:
- id: set-vars
run: echo "llvm-commit=$LLVM_COMMIT" >> "$GITHUB_OUTPUT"
build-mlir-cache:
needs: config
uses: ./.github/workflows/build_mlir_cache.yml
with:
llvm-commit: ${{ needs.config.outputs.llvm-commit }}
validate:
needs: build-mlir-cache
validate-operations:
runs-on: ubuntu-latest
steps:
@@ -39,7 +24,7 @@ jobs:
- name: Install system dependencies
run: |
sudo apt update
sudo apt install -y ninja-build build-essential curl ca-certificates
sudo apt install -y cmake ninja-build build-essential mold curl ca-certificates
- name: Install CMake
run: |
@@ -58,27 +43,17 @@ jobs:
cmake --version
which cmake
- name: Cache protobuf build
id: cache-protobuf
uses: actions/cache@v4
- name: Prepare MLIR cache
uses: ./.github/actions/prepare-mlir-cache
with:
path: |
/usr/local/lib/libproto*
/usr/local/include/google/protobuf
key: protobuf-${{ runner.os }}-${{ env.PROTOBUF_COMMIT }}
llvm-commit: ${{ env.LLVM_COMMIT }}
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
- name: Install protobuf
if: steps.cache-protobuf.outputs.cache-hit != 'true'
run: |
git clone --depth 1 --branch ${{ env.PROTOBUF_COMMIT }} https://github.com/protocolbuffers/protobuf
cd protobuf
mkdir build
cd build
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
ninja
sudo ninja install
cd ../..
rm -rf protobuf
- name: Prepare protobuf cache
uses: ./.github/actions/prepare-protobuf-cache
with:
protobuf-commit: ${{ env.PROTOBUF_COMMIT }}
mold-linker-flags: ${{ env.MOLD_LINKER_FLAGS }}
- name: Register installed libraries
run: sudo ldconfig
@@ -94,23 +69,34 @@ jobs:
- name: Install Python dependencies
run: pip install numpy onnx colorama
- name: Restore MLIR cache
uses: actions/cache/restore@v4
- name: Restore raptor build cache
id: restore-raptor-build-cache
uses: ./.github/actions/restore-raptor-build-cache
with:
path: onnx-mlir/llvm-project
key: mlir-${{ runner.os }}-${{ env.LLVM_COMMIT }}
fail-on-cache-miss: true
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
restore-keys: |
raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-
raptor-build-${{ runner.os }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-
- name: Build Raptor
id: build-raptor
run: |
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
mkdir -p build && cd build
cmake .. -G Ninja \
cmake -S . -B build -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
cmake --build .
-DMLIR_DIR=${MLIR_DIR} \
-DCMAKE_EXE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
-DCMAKE_SHARED_LINKER_FLAGS="${MOLD_LINKER_FLAGS}" \
-DCMAKE_MODULE_LINKER_FLAGS="${MOLD_LINKER_FLAGS}"
cmake --build build
- name: Save raptor build cache
if: steps.build-raptor.outcome == 'success' && steps.restore-raptor-build-cache.outputs.cache-hit != 'true'
uses: ./.github/actions/save-raptor-build-cache
with:
key: raptor-build-${{ runner.os }}-${{ github.ref_name }}-${{ env.LLVM_COMMIT }}-${{ env.PROTOBUF_COMMIT }}-${{ env.CMAKE_VERSION }}-${{ github.sha }}
- name: Run validation
run: |
+9 -1
View File
@@ -29,7 +29,7 @@ Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
Moreover, if compiling with build type debug, it is also suggested to use
mold as linker (you will need to install it if you don't have it already)
to reduce memory usage during linking. You can use it with:
to reduce memory usage during linking. You can use it by setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
@@ -38,8 +38,16 @@ to reduce memory usage during linking. You can use it with:
### Raptor
Use the following commands to build Raptor.
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold"
```
```
git submodule update --init --recursive
@@ -530,6 +530,7 @@ where
let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0];
@@ -224,7 +224,21 @@ fn json_to_vvsub(
inst_data_builder: &mut InstructionDataBuilder,
json: &Value,
) -> Result<()> {
todo!("Not present in the compiler");
let json = json.as_object().expect("Not an object");
assert_eq!("vvsub", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvsub, inst_data_builder.build());
Ok(())
}
@@ -256,7 +270,21 @@ fn json_to_vvdmul(
inst_data_builder: &mut InstructionDataBuilder,
json: &Value,
) -> Result<()> {
todo!("Not present in the compiler");
let json = json.as_object().expect("Not an object");
assert_eq!("vvdmul", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vvdmul, inst_data_builder.build());
Ok(())
}
@@ -306,7 +334,21 @@ fn json_to_vavg(
inst_data_builder: &mut InstructionDataBuilder,
json: &Value,
) -> Result<()> {
todo!("Not present in the compiler");
let json = json.as_object().expect("Not an object");
assert_eq!("vavg", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let rs2 = json_i64!(json, "rs2") as i32;
let len = json_i64!(json, "len") as i32;
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_r1(rs1)
.set_r2(rs2)
.set_imm_len(len)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
inst_builder.make_inst(vavg, inst_data_builder.build());
Ok(())
}
@@ -358,7 +400,7 @@ fn json_to_vsigm(
json: &Value,
) -> Result<()> {
let json = json.as_object().expect("Not an object");
assert_eq!("vsigmoid", json_str!(json, "op"));
assert_eq!("vsigm", json_str!(json, "op"));
let rd = json_i64!(json, "rd") as i32;
let rs1 = json_i64!(json, "rs1") as i32;
let len = json_i64!(json, "len") as i32;
+38 -17
View File
@@ -10,34 +10,54 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
add_subdirectory(Common)
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_onnx_mlir_library(OMPIMAccel
PimAccelerator.cpp
Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp
Pass/PimConstantFoldingPass.cpp
Pass/PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
)
set(PIM_COMPILER_INCLUDE_DIRS
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
set(PIM_ACCEL_INCLUDE_DIRS
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
)
set(PIM_GENERATED_INCLUDE_DIRS
${PIM_INCLUDE_PATH}
)
function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN})
endfunction()
add_subdirectory(Dialect)
add_subdirectory(Common)
add_subdirectory(Pass)
add_subdirectory(Compiler)
add_subdirectory(Conversion)
add_pim_library(OMPIMAccel
PimAccelerator.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
onnx
OMAccelerator
OMPimCompilerUtils
OMCompilerUtils
OMPimPasses
OMONNXOps
SpatialOps
PimOps
@@ -45,5 +65,6 @@ add_onnx_mlir_library(OMPIMAccel
OMSpatialToGraphviz
OMSpatialToPim
OMPimCommon
OMPimBufferization
MLIRTensorInferTypeOpInterfaceImpl
)
+3 -9
View File
@@ -1,19 +1,13 @@
add_onnx_mlir_library(OMPimCommon
add_pim_library(OMPimCommon
PimCommon.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
onnx
OMPimCompilerUtils
SpatialOps
PimOps
)
)
+63
View File
@@ -1,3 +1,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
@@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
return true;
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
int64_t byteOffset = 0;
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress{value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = tiedOperand->get();
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|| llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress{value, byteOffset};
return failure();
}
}
} // namespace onnx_mlir
+7
View File
@@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
@@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
} // namespace onnx_mlir
-44
View File
@@ -1,44 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
template <typename T>
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
map.erase(result);
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
map.erase(arg);
}
};
template <typename T>
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
NotErasableValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
}
};
+12 -19
View File
@@ -1,44 +1,37 @@
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
add_onnx_mlir_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerOptions
PimCompilerOptions.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_COMPILER_INCLUDE_DIRS}
LINK_LIBS PUBLIC
${OMLibs}
OMCompilerOptions
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_ACCEL_INCLUDE_DIRS}
)
add_onnx_mlir_library(OMPimCompilerUtils
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimCodeGen.cpp
../Pass/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PRIVATE
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_COMPILER_INCLUDE_DIRS}
LINK_LIBS PUBLIC
${OMLibs}
OMCompilerUtils
OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
OMCompilerPasses
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_BIN_ROOT}
${PIM_ACCEL_INCLUDE_DIRS}
)
+142 -62
View File
@@ -14,12 +14,11 @@
#include <cmath>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace llvm;
using namespace mlir;
@@ -86,48 +85,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
}
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
size_t offset = 0;
while (true) {
auto definingOp = value.getDefiningOp();
if (!definingOp)
break;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
break;
value = tiedOperand->get();
}
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto source = subviewDefiningOp.getSource();
auto srcShape = source.getType().getShape();
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
size_t localOffset = subviewOffsets[i];
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
localOffset *= subviewSizes[j];
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
}
value = source;
}
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
}
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
}
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
}
else
break;
}
auto iter = memEntriesMap.find(value);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
auto resolvedAddress = resolveContiguousAddress(value);
if (failed(resolvedAddress)) {
errs() << "Failed to resolve contiguous address for value: ";
value.print(errs());
errs() << "\n";
if (auto* definingOp = value.getDefiningOp()) {
@@ -135,10 +95,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Failed to resolve contiguous address");
}
auto iter = memEntriesMap.find(resolvedAddress->base);
if (iter == memEntriesMap.end()) {
errs() << "Missing mem entry for value: ";
resolvedAddress->base.print(errs());
errs() << "\n";
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
errs() << "Defining op:\n";
definingOp->print(errs());
errs() << "\n";
}
llvm_unreachable("Missing mem entry");
}
return iter->second.address + offset;
return iter->second.address + resolvedAddress->byteOffset;
}
json::Object PimCodeGen::createEmptyOffset() {
@@ -264,14 +237,16 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeM
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
auto outBufAddr = memory.getValueAddress(vaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vaddOp.getA());
auto bAddr = memory.getValueAddress(vaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
auto outputType = cast<MemRefType>(vaddOp.getOutBuf().getType());
size_t totalBytes = outputType.getNumElements() * vaddOp.getOutRes().getType().getElementTypeBitWidth() / 8;
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
auto outBufAddr = memory.getValueAddress(vvaddOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvaddOp.getA());
auto bAddr = memory.getValueAddress(vvaddOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvadd";
@@ -279,14 +254,46 @@ void PimCodeGen::codeGenVAddOp(pim::PimVAddOp vaddOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = totalBytes;
json["len"] = getValueSizeInBytes(vvaddOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
auto outBufAddr = memory.getValueAddress(vmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vmaxOp.getA());
auto bAddr = memory.getValueAddress(vmaxOp.getB());
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outBufAddr = memory.getValueAddress(vvsubOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvsubOp.getA());
auto bAddr = memory.getValueAddress(vvsubOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvsub";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvsubOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outBufAddr = memory.getValueAddress(vvmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmulOp.getA());
auto bAddr = memory.getValueAddress(vvmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outBufAddr = memory.getValueAddress(vvmaxOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvmaxOp.getA());
auto bAddr = memory.getValueAddress(vvmaxOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
@@ -295,6 +302,37 @@ void PimCodeGen::codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const {
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvmaxOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outBufAddr = memory.getValueAddress(vvdmulOp.getOutBuf());
auto aAddr = memory.getValueAddress(vvdmulOp.getA());
auto bAddr = memory.getValueAddress(vvdmulOp.getB());
setupRdRs1Rs2(outBufAddr, 0, aAddr, 0, bAddr, 0);
json::Object json;
json["op"] = "vvdmul";
json["rd"] = 0;
json["rs1"] = 1;
json["rs2"] = 2;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vvdmulOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outBufAddr = memory.getValueAddress(vavgOp.getOutBuf());
auto aAddr = memory.getValueAddress(vavgOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vavg";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vavgOp.getA());
emitInstruction(std::move(json));
}
@@ -308,6 +346,35 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vreluOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outBufAddr = memory.getValueAddress(vtanhOp.getOutBuf());
auto aAddr = memory.getValueAddress(vtanhOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vtanh";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vtanhOp.getA());
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outBufAddr = memory.getValueAddress(vsigmOp.getOutBuf());
auto aAddr = memory.getValueAddress(vsigmOp.getA());
setupRdRs1(outBufAddr, 0, aAddr, 0);
json::Object json;
json["op"] = "vsigm";
json["rd"] = 0;
json["rs1"] = 1;
json["offset"] = createEmptyOffset();
json["len"] = getValueSizeInBytes(vsigmOp.getA());
emitInstruction(std::move(json));
}
@@ -365,6 +432,7 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
vaddJson["rs1"] = 1;
vaddJson["rs2"] = 2;
vaddJson["offset"] = createEmptyOffset();
vaddJson["len"] = 32 * outChannels;
emitInstruction(std::move(vaddJson));
}
}
@@ -506,13 +574,25 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
coreCodeGen.codeGenVAddOp(vaddOp);
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
coreCodeGen.codeGenVMaxOp(vmaxOp);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp);
else if (isa<pim::PimSumOp, pim::PimVSDivOp, pim::PimVExpOp>(op)) {
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (isa<pim::PimSumOp>(op)) {
// TODO: Implement somehow?
op.emitWarning("Operation is not yet supported in code generation");
continue;
+9 -4
View File
@@ -3,8 +3,7 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/Support/JSON.h"
#include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -91,9 +90,15 @@ public:
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
};
+3 -3
View File
@@ -5,7 +5,6 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#define DEBUG_TYPE "PimCompilerUtils"
@@ -48,8 +47,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimHostVerificationPass());
pm.addPass(createMessagePass("Pim host verified"));
pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted"));
@@ -2,23 +2,27 @@ set(LLVM_TARGET_DEFINITIONS ONNXToSpatial.td)
mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.cpp
Math/Conv.cpp
NN/Pooling.cpp
NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp
Tensor/RemoveUnusedHelperOps.cpp
add_pim_library(OMONNXToSpatial
Patterns/Math/Gemm.cpp
Patterns/Math/Conv.cpp
Patterns/Math/MatMul.cpp
Patterns/NN/Pool.cpp
Patterns/NN/ReduceMean.cpp
Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp
Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp
ONNXToSpatialPass.cpp
ONNXToSpatialCommon.cpp
Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions
OMPimCompilerOptions
OMONNXOps
@@ -26,5 +30,5 @@ add_onnx_mlir_library(OMONNXToSpatial
OMPimCommon
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
${PIM_GENERATED_INCLUDE_DIRS}
)
@@ -15,7 +15,7 @@
#include <optional>
#include <utility>
#include "ONNXToSpatialCommon.hpp"
#include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,427 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
// Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1)
return valuesToReduce[0];
if (preprocess) {
for (auto& valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce);
}
}
// It is possible that `valuesToReduce` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& valToReduce : valuesToReduce) {
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
// } else {
// // Otherwise it is a block argument,
// computeOp->getBlock()->getParentOp();
// }
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
auto it = lastValueForCompute.find(computeOp);
if (it != lastValueForCompute.end()) {
// If we have already seen this computeOp, apply the reduction
// within-compute
Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
else
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce;
}
lastValueForCompute[computeOp] = valToReduce;
}
// Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size());
for (auto& entry : lastValueForCompute)
valuesToReduce.push_back(entry.second);
Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
// Recursive algorithm to reduce the inputs to a single one:
// - Take two inputs at a time, and reduce them into a single one, updating
// the valuesToReduce list which becomes half the size.
// - Repeat until there is only one input left.
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
while (valuesToReduceRef.size() > 1) {
SmallVector<Value> nextValuesToReduce;
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
auto firstValue = valuesToReduceRef[i];
auto secondValue = valuesToReduceRef[i + 1];
auto firstCompute = firstValue.getParentBlock()->getParentOp();
auto secondCompute = secondValue.getParentBlock()->getParentOp();
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
if (secondCompute->isBeforeInBlock(firstCompute)) {
std::swap(firstValue, secondValue);
std::swap(firstCompute, secondCompute);
}
// 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute);
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue);
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
// 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue);
Value reduced = reduce(receivedValue, secondValue);
nextValuesToReduce.push_back(reduced);
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (valuesToReduceRef.size() % 2 == 1)
nextValuesToReduce.push_back(valuesToReduceRef.back());
// Replace the inputOps list with the new one.
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
}
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0];
if (postprocess) {
rewriter.setInsertionPointAfterValue(finalValue);
finalValue = postprocess(finalValue);
}
return finalValue;
}
template <typename PoolOp>
bool hasPostProcessPoolingWindow() {
return false;
}
template <>
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
return true;
}
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
Location loc = poolOp.getLoc();
size_t input_h = getImageHeight(xShape);
size_t input_w = getImageWidth(xShape);
size_t output_h = getImageHeight(yShape);
size_t output_w = getImageWidth(yShape);
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
// 1: Tile the input tensor
// Input tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt =
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
// TODO: This requires a core for each input tile, which is not ideal. We
// can do better.
// If some input tiles come from the func.func operands, load
// them into a computeOp and yield them
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
tileLoc,
extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(),
extractSliceOp.getResult());
Block* tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock);
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0);
}
}
}
}
// 2: Tile the output tensor
// Output tiles need to be indexed by:
// a. Channel Tile
// b. Pixel `x` position
// c. Pixel `y` position
// For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel
SmallVector<Value> valuesToPool;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
// Each output pixel tile is computed by pooling a window of input
// pixel tiles
valuesToPool.clear();
size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++;
continue;
}
Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool;
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber);
}
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
}
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData();
}
else {
return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
}
valuesToPool.push_back(valueToPool);
}
}
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???");
std::function<Value(const Value&)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
};
}
Value reducedWithinCompute = applyReducePatternNew(
valuesToPool,
rewriter,
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
nullptr,
postProcessFn);
// Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the
// users computeOps
auto computeOpOfReduced =
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel =
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue;
}
}
}
// TODO: outputTiles are not the results of the computeOps! We need to add
// them!
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
// Iterate each output pixel
for (size_t outX = 0; outX < output_w; outX++) {
for (size_t outY = 0; outY < output_h; outY++) {
auto outputTile = outputTiles[outTile][outX][outY];
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
}
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
}
}
}
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage);
return success();
}
};
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir
@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
// ONNXMatMulOp to ONNXGemmOp patterns
def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">;
def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C,
@@ -22,19 +26,21 @@ def matMulAddToGemmPattern : Pat<
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
),
[(IsRank2Result $matmulres)]
>;
def matMulToGemmPattern : Pat<
(ONNXMatMulOp:$matmulres $A, $B),
(
ONNXGemmOp $A, $B,
ONNXGemmOp $A, $B,
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
),
[(IsRank2Result $matmulres)]
>;
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
@@ -1,8 +1,9 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
@@ -12,7 +13,7 @@
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -56,6 +57,7 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
@@ -74,7 +76,9 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
});
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
@@ -83,13 +87,15 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXReshapeOp>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
populateConvOpPatterns(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populatePoolTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx);
@@ -113,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() {
}
}
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc);
@@ -7,13 +7,15 @@ namespace onnx_mlir {
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC;
if (hasB) {
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
gemmC = tensor::ExpandShapeOp::create(rewriter,
loc,
biasType,
b,
SmallVector<ReassociationIndices> {
{0, 1}
});
}
if (hasB)
gemmC = b;
else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
@@ -11,7 +11,7 @@
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -23,6 +23,38 @@ namespace {
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
static FailureOr<Value> materializeScaledConstantTensor(Value value,
float factor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (factor == 1.0f)
return value;
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
if (!denseAttr)
return failure();
SmallVector<APFloat> scaledValues;
scaledValues.reserve(denseAttr.getNumElements());
APFloat scale(factor);
bool hadFailure = false;
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
APFloat scaledValue(originalValue);
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
hadFailure = true;
scaledValues.push_back(std::move(scaledValue));
}
if (hadFailure)
return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
if (numOutRows <= 1)
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
RankedTensorType cType = nullptr;
bool cHasNumOutRows = false;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
aSlice,
b,
cSlice,
gemmOp.getAlphaAttr(),
gemmOp.getBetaAttr(),
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
gemmOp.getTransAAttr(),
gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY());
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
}
@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType());
}
if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
alpha = 1.0f;
}
if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
beta = 1.0f;
}
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
return failure();
const int64_t batch = rhsType.getDimSize(0);
const int64_t k = rhsType.getDimSize(1);
const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
return failure();
Location loc = matmulOp.getLoc();
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhsTransposed =
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
SmallVector<Value> gemmRows;
gemmRows.reserve(batch * n);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
auto* concatBlock = new Block();
for (Value gemmRow : gemmRows)
concatBlock->addArgument(gemmRow.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
Value result = ONNXTransposeOp::create(
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result);
return success();
}
};
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
return reduced;
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
if (divisor == 1)
return reducedWindow;
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
}
template <typename PoolOp>
struct PoolToSpatialCompute;
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
using OpConversionPattern<PoolOp>::OpConversionPattern;
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location loc = poolOp.getLoc();
Value x = adaptor.getX();
auto xType = dyn_cast<RankedTensorType>(x.getType());
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
if (xType.getRank() != 4 || outType.getRank() != 4)
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
ArrayAttr kernelAttr = poolOp.getKernelShape();
if (!kernelAttr || kernelAttr.size() != 2)
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
const int64_t batchSize = xType.getDimSize(0);
const int64_t channels = xType.getDimSize(1);
const int64_t inputHeight = xType.getDimSize(2);
const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
int64_t padTop = 0;
int64_t padLeft = 0;
int64_t padBottom = 0;
int64_t padRight = 0;
if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2);
padRight = getI64(*padsAttr, 3);
}
else {
StringRef autoPad = poolOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
if (autoPad == "SAME_UPPER") {
padTop = totalPadH / 2;
padBottom = totalPadH - padTop;
padLeft = totalPadW / 2;
padRight = totalPadW - padLeft;
}
else {
padBottom = totalPadH / 2;
padTop = totalPadH - padBottom;
padRight = totalPadW / 2;
padLeft = totalPadW - padRight;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
}
}
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
auto* computeBlock = new Block();
computeBlock->addArgument(xType, loc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
Value input = computeBlock->getArgument(0);
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.replaceOp(poolOp, computeOp.getResult(0));
return success();
}
};
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
} // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
} // namespace onnx_mlir
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -1,15 +1,15 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext* ctx)
struct Concat : public OpConversionPattern<ONNXConcatOp> {
Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
};
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx);
patterns.insert<Concat>(ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,121 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(sourceIdx);
while (sourceProduct != resultProduct) {
if (sourceProduct > resultProduct)
return false;
sourceIdx++;
if (sourceIdx >= sourceShape.size())
return false;
group.push_back(sourceIdx);
sourceProduct *= sourceShape[sourceIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(resultIdx);
while (resultProduct != sourceProduct) {
if (resultProduct > sourceProduct)
return false;
resultIdx++;
if (resultIdx >= resultShape.size())
return false;
group.push_back(resultIdx);
resultProduct *= resultShape[resultIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
ONNXReshapeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
return failure();
if (sourceType == resultType) {
rewriter.replaceOp(reshapeOp, adaptor.getData());
return success();
}
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
return failure();
}
};
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Reshape>(ctx);
}
} // namespace onnx_mlir
@@ -1,35 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {}
void initialize() { this->setHasBoundedRewriteRecursion(); }
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
if (op.getResult().use_empty()) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
}
} // namespace onnx_mlir
@@ -1,7 +1,7 @@
#include <queue>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -7,7 +7,7 @@
#include <unordered_map>
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
@@ -1,14 +1,17 @@
add_onnx_mlir_rewriter(SpatialToGraphviz)
add_onnx_mlir_library(OMSpatialToGraphviz
add_pim_library(OMSpatialToGraphviz
SpatialToGraphviz.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions
OMPimCommon
OMONNXOps
SpatialOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
${PIM_GENERATED_INCLUDE_DIRS}
)
@@ -2,19 +2,22 @@ set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen)
add_onnx_mlir_library(OMSpatialToPim
add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
SpatialToPimCommon.cpp
Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
SpatialToPimIncGen
LINK_LIBS PUBLIC
MLIRTosaDialect
OMCompilerOptions
OMPimCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
${PIM_GENERATED_INCLUDE_DIRS}
)
@@ -5,7 +5,7 @@
#include <cassert>
#include <cstddef>
#include "SpatialToPimCommon.hpp"
#include "Common.hpp"
using namespace llvm;
using namespace mlir;
@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVAddOp : Pat<
def spatToPimVVAddOp : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVAddOp $a, $b,
(PimVVAddOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMulOp : Pat<
(SpatVMulOp:$srcOpRes $a, $b),
(PimVVMulOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVMaxOp : Pat<
(SpatVMaxOp:$srcOpRes $a, $b),
(PimVVMaxOp $a, $b,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -18,9 +19,9 @@
#include <string>
#include <utility>
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -56,10 +57,8 @@ private:
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value channelSourceOp,
spatial::SpatChannelNewOp& channel,
bool useBroadcastOp,
IRRewriter& rewriter);
void
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
@@ -79,8 +78,31 @@ private:
} // namespace
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
op);
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static size_t countComputeLeafUsers(Value value) {
@@ -175,35 +197,67 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
// If this result has no uses, then just skip it
if (result.use_empty())
continue;
auto yieldType = cast<TensorType>(yieldValue.getType());
/*
* Here we assume that ReturnOp are only reachable by the following patterns:
*
* 1)
* %0 = spat.compute([...])
* [%0 has one user, which is a ConcatOp]
* %1 = tensor.concat(%0)
* [%1 has one user, which is a ReturnOp]
* return %1
*
* 2)
* %0 = spat.compute([...])
* [%0 has one user, which is a ReturnOp]
* return %0
*
* If the IR is like 2), then we can store the tensor to the output global memory location
*/
auto resultUses = result.getUses();
auto numResultUses = rangeLength(resultUses);
if (numResultUses == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isChannelUseChainOp(resultUser)) {
SmallVector<Operation*> returnChain;
Value chainedValue = result;
Operation* chainUser = resultUser;
while (isChannelUseChainOp(chainUser)) {
returnChain.push_back(chainUser);
auto chainUses = chainUser->getResult(0).getUses();
if (rangeLength(chainUses) != 1)
break;
chainedValue = chainUser->getResult(0);
chainUser = chainUses.begin()->getOwner();
}
if (isa<func::ReturnOp>(chainUser)) {
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
rewriter.setInsertionPoint(yieldOp);
IRMapping mapping;
mapping.map(result, yieldValue);
Value storedValue = yieldValue;
for (Operation* op : returnChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
storedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
markOpToRemove(op);
}
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
Value outputTensor = outputTensors[resultIndexInReturn];
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
storedValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
continue;
}
}
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t offset = 0;
@@ -475,7 +529,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) {
@@ -493,6 +548,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
IRMapping mapping;
mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -502,7 +558,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
replacementValue = cast<Value>(mapping.lookup(consumerValue));
}
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
assert(replacementValue.getType() == blockArg.getType()
&& "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue);
}
@@ -1,12 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
namespace spatial {
// TODO: Add here eventual patterns
}
} // namespace onnx_mlir
+3 -1
View File
@@ -3,10 +3,12 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_onnx_mlir_library(PimOps
add_pim_library(PimOps
PimOps.hpp
PimOps.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
OMPimIncGen
+101 -9
View File
@@ -251,7 +251,7 @@ def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> {
}];
}
def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
def PimVVAddOp: PimOp<"vvadd", [DestinationStyleOpInterface]> {
let description = [{
Element-wise addition: c = a + b
}];
@@ -277,7 +277,59 @@ def PimVAddOp: PimOp<"vadd", [DestinationStyleOpInterface]> {
}];
}
def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
def PimVVSubOp: PimOp<"vvsub", [DestinationStyleOpInterface]> {
let description = [{
Element-wise subtraction: c = a - b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMulOp: PimOp<"vvmul", [DestinationStyleOpInterface]> {
let description = [{
Element-wise multiplication: c = a * b
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVMaxOp: PimOp<"vvmax", [DestinationStyleOpInterface]> {
let description = [{
Element-wise max: c = max(a, b)
}];
@@ -291,6 +343,32 @@ def PimVMaxOp: PimOp<"vmax", [DeclareOpInterfaceMethods<BufferViewFlowOpInterfac
let results = (outs
PimTensor: $outRes
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutBufMutable();
}
}];
let assemblyFormat = [{
`(` $a `,` $b `,` $outBuf `)` attr-dict `:` `(` type($a) `,` type($b) `,` type($outBuf) `)` `->` type($outRes)
}];
}
def PimVVDMulOp: PimOp<"vvdmul", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Dot product: c = dot(a, b)
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $b,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimApplyFiltersOp: PimOp<"apply_filters", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
@@ -332,14 +410,13 @@ def PimSumOp: PimOp<"sum", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>
);
}
def PimVSDivOp: PimOp<"vsdiv", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
def PimVAvgOp: PimOp<"vavg", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise division between each element of a vector, and a scalar (wrapped in a tensor for convenience)
Average all elements into a single one
}];
let arguments = (ins
PimTensor: $dividend,
PimTensor: $divisor,
PimTensor: $a,
PimTensor: $outBuf
);
@@ -363,9 +440,24 @@ def PimVReluOp: PimOp<"vrelu", [DeclareOpInterfaceMethods<BufferViewFlowOpInterf
);
}
def PimVExpOp: PimOp<"vexp", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
def PimVTanhOp: PimOp<"vtanh", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise exp: c = exp(a)
Element-wise tanh activation
}];
let arguments = (ins
PimTensor: $a,
PimTensor: $outBuf
);
let results = (outs
PimTensor: $outRes
);
}
def PimVSigmOp: PimOp<"vsigm", [DeclareOpInterfaceMethods<BufferViewFlowOpInterface>]> {
let description = [{
Element-wise sigmoid activation
}];
let arguments = (ins
@@ -388,4 +480,4 @@ def PimHaltOp: PimOp<"halt", [Terminator]> {
}];
}
#endif // PIM_DIALECT_H
#endif // PIM_DIALECT_H
+4 -3
View File
@@ -30,12 +30,13 @@ void PimDialect::initialize() {
registerDependenciesFn(this->getOutBuf(), this->getResult()); \
}
POPULATE_DEPENDENCIES(PimVMaxOp)
POPULATE_DEPENDENCIES(PimVVDMulOp)
POPULATE_DEPENDENCIES(PimApplyFiltersOp)
POPULATE_DEPENDENCIES(PimSumOp)
POPULATE_DEPENDENCIES(PimVSDivOp)
POPULATE_DEPENDENCIES(PimVAvgOp)
POPULATE_DEPENDENCIES(PimVReluOp)
POPULATE_DEPENDENCIES(PimVExpOp)
POPULATE_DEPENDENCIES(PimVTanhOp)
POPULATE_DEPENDENCIES(PimVSigmOp)
} // namespace pim
} // namespace onnx_mlir
@@ -2,13 +2,15 @@ set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(PimBufferizationIncGen)
add_onnx_mlir_library(OMPimBufferization
add_pim_library(OMPimBufferization
PimBufferizationPass.cpp
OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp
Common.hpp
Common.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
PimBufferizationIncGen
@@ -17,5 +19,5 @@ add_onnx_mlir_library(OMPimBufferization
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
${PIM_GENERATED_INCLUDE_DIRS}
)
@@ -4,6 +4,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -12,6 +13,26 @@ using namespace bufferization;
namespace onnx_mlir {
namespace pim {
static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op,
@@ -164,7 +185,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
}
};
struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOpBufferizeInterface, PimVAddOp> {
template <typename OpTy>
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
@@ -179,21 +201,24 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto vaddOp = cast<PimVAddOp>(op);
auto binaryOp = cast<OpTy>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
auto aOpt = getBuffer(rewriter, binaryOp.getA(), options, state);
if (failed(aOpt))
return failure();
auto bOpt = getBuffer(rewriter, vaddOp.getB(), options, state);
auto bOpt = getBuffer(rewriter, binaryOp.getB(), options, state);
if (failed(bOpt))
return failure();
auto outBufOpt = getBuffer(rewriter, vaddOp.getOutBuf(), options, state);
auto outBufOpt = getBuffer(rewriter, binaryOp.getOutBuf(), options, state);
if (failed(outBufOpt))
return failure();
replaceOpWithNewBufferizedOp<PimVAddOp>(rewriter, op, outBufOpt->getType(), *aOpt, *bOpt, *outBufOpt);
Value contiguousA = materializeContiguousMemRef(*aOpt, op->getLoc(), rewriter);
Value contiguousB = materializeContiguousMemRef(*bOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outBufOpt->getType(), contiguousA, contiguousB, *outBufOpt);
return success();
}
};
@@ -205,7 +230,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimTransposeOp::attachInterface<TransposeOpBufferizeInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpBufferizeInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpBufferizeInterface>(*ctx);
PimVAddOp::attachInterface<VAddOpBufferizeInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpBufferizeInterface<PimVVMaxOp>>(*ctx);
});
}
+9 -3
View File
@@ -1,16 +1,22 @@
add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)
add_onnx_mlir_library(SpatialOps
add_pim_library(SpatialOps
SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp
EXCLUDE_FROM_OM_LIBS
DEPENDS
OMONNXIncGen
OMSpatialIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRBufferizationDialect
MLIRBufferizationTransforms
OMMlirDialects
OMONNXOps
OMPimCompilerOptions
PimOps
)
+1 -1
View File
@@ -24,7 +24,7 @@
#include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -37,8 +37,76 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
return memref::AllocOp::create(rewriter, loc, memrefResultType);
}
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return pim::PimMemCopyOp::create(rewriter,
loc,
contiguousBuffer.getType(),
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getDstOut();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
++usersIterator;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
++usersIterator;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
return failure();
}
Operation* otherUser = nullptr;
if (firstUser == op)
otherUser = secondUser;
else if (secondUser == op)
otherUser = firstUser;
else {
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
return failure();
}
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
return failure();
}
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
return failure();
}
return otherUser;
}
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
// This function requires the existence of ChannelNewOp and the other
@@ -49,7 +117,7 @@ llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsR
if (precomputedOtherCoreId)
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive, rewriter);
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
if (failed(notOpUserOpt))
return failure();
Operation* notOpUser = *notOpUserOpt;
@@ -119,7 +187,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref))
return failure();
memrefOperands.push_back(*memref);
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
}
// TODO: Support addiction with more than 2 operands
@@ -412,7 +480,7 @@ struct ChannelBroadcastSendOpInterface
};
struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVAddOp> {};
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
@@ -420,9 +488,7 @@ struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, Spa
struct SumOpInterface : VariadicArgumentElementWiseOpInterface<SumOpInterface, SpatSumOp, pim::PimSumOp> {};
struct VSDivOpInterface : VariadicArgumentElementWiseOpInterface<VSDivOpInterface, SpatVSDivOp, pim::PimVSDivOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVMaxOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
// Create a new bufferizable op interface for the apply filters operation.
struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFiltersOpInterface, SpatApplyFiltersOp> {
@@ -509,7 +575,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatSumOp::attachInterface<SumOpInterface>(*ctx);
SpatVSDivOp::attachInterface<VSDivOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
@@ -521,12 +586,16 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXExpOpInterface : VariadicArgumentElementWiseOpInterface<ONNXExpOpInterface, ONNXExpOp, pim::PimVExpOp> {};
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
struct ONNXSigmoidInterface
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXExpOp::attachInterface<ONNXExpOpInterface>(*ctx);
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
});
}
+16
View File
@@ -0,0 +1,16 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
PimConstantFolding/Common.cpp
PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp
PimMaterializeConstantsPass.cpp
PimVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRLinalgDialect
OMPimCommon
)
-1
View File
@@ -3,7 +3,6 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace mlir;
+6 -2
View File
@@ -1,6 +1,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Compiler/CompilerUtils.hpp"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -17,7 +18,10 @@ struct MessagePass : PassWrapper<MessagePass, OperationPass<ModuleOp>> {
: message(message) {}
MessagePass(const MessagePass& pass) {}
void runOnOperation() final { showCompilePhase(message); }
void runOnOperation() final {
llvm::outs() << message << "\n";
llvm::outs().flush();
}
private:
std::string message;
+121
View File
@@ -0,0 +1,121 @@
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment) {
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
} // namespace onnx_mlir
@@ -0,0 +1,41 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
mlir::Location loc,
mlir::MemRefType globalType,
mlir::DenseElementsAttr denseAttr,
llvm::StringRef nameStem,
mlir::IntegerAttr alignment = {});
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
llvm::ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth);
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -1,19 +1,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -21,55 +13,14 @@
using namespace mlir;
namespace onnx_mlir {
namespace {
static Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
DenseElementsAttr denseAttr,
StringRef nameStem,
IntegerAttr alignment = {}) {
auto globalName = nameStem.str();
unsigned suffix = 0;
while (moduleOp.lookupSymbol(globalName))
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
return memref::GlobalOp::create(moduleBuilder,
loc,
globalName,
visibility,
globalType,
denseAttr,
/*constant=*/true,
alignment);
}
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
value = stripMemRefCasts(value);
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return failure();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
return denseAttr;
}
struct ConstantSubviewCopy {
DenseElementsAttr source;
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
Operation* copyOp = nullptr;
};
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
@@ -126,13 +77,6 @@ static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr den
return DenseElementsAttr::get(transposedType, transposedValues);
}
struct ConstantSubviewCopy {
DenseElementsAttr source;
SmallVector<int64_t> offsets;
SmallVector<int64_t> strides;
Operation* copyOp = nullptr;
};
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
if (!mapOp.getInputs().empty())
return failure();
@@ -176,151 +120,13 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
if (elementByteWidth == 0)
return failure();
size_t totalBytes = initType.getNumElements() * elementByteWidth;
rewriter.setInsertionPoint(mapOp);
pim::PimMemCopyHostToDevOp::create(rewriter,
mapOp.getLoc(),
initType,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
rewriter.eraseOp(mapOp);
return success();
}
};
struct StaticSubviewInfo {
Value source;
SmallVector<int64_t> sourceShape;
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
};
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
static int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
}
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (copyOp.getSize() != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = copyOp.getSrcOffset()
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = copyOp.getDstOffset()
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : copyOp.getDst(),
splitSrc ? srcSubview->source : copyOp.getSrc(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
@@ -473,17 +279,15 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
if (!llvm::equal(transposedShape, resultType.getShape()))
return failure();
MemRefType globalType = resultType;
auto newGlobal = createFoldedGlobal(moduleOp,
transposeOp.getLoc(),
globalType,
resultType,
*transposedAttr,
sourceGlobal.getName().str() + "__folded_transpose",
sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
bool isAlwaysWeight =
!transposeOp->getUsers().empty()
@@ -578,41 +382,106 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
}
};
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
owningPatterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
context);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
if (!allocOp)
return failure();
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
signalPassFailure();
return;
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
return failure();
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
sourceIndices.push_back(off + idx);
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
resultValues[i] = sourceValues[srcLinear];
}
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
dumpModule(getOperation(), "pim2_folded");
}
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
continue;
if (isa<memref::DeallocOp>(user))
continue;
if (isa<pim::PimCoreOp>(user))
continue;
if (isa<memref::SubViewOp>(user)) {
allLiveUsersAreCores = false;
continue;
}
return failure();
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
if (allLiveUsersAreCores)
markWeightAlways(newGetGlobal);
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
rewriter.eraseOp(copyOp);
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success();
}
};
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantMemCpPattern>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -0,0 +1,223 @@
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src,
dstByteOffset,
srcByteOffset,
sliceBytes);
}
return success();
}
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDst(),
copyOp.getSrc(),
copyOp.getDstOffset(),
copyOp.getSrcOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDst());
return success();
}
};
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto status =
rewriteSubviewCopyLikeOp(copyOp,
copyOp.getDeviceDst(),
copyOp.getHostSrc(),
copyOp.getDeviceDstOffset(),
copyOp.getHostSrcOffset(),
copyOp.getSize(),
rewriter,
[&](MemRefType resultType,
Value dst,
Value src,
int64_t dstByteOffset,
int64_t srcByteOffset,
int64_t sliceBytes) {
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(),
resultType,
dst,
src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
});
if (failed(status))
return failure();
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
return success();
}
};
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
if (subviewOp.use_empty())
return failure();
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
return failure();
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
if (failed(denseAttr))
return failure();
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
if (failed(subviewInfo))
return failure();
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
auto resultMemRefType =
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
sourceIndices.push_back(off + idx);
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
return success();
}
};
} // namespace
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
patterns.add<RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern, FoldConstantCoreSubviewPattern>(
patterns.getContext());
}
} // namespace onnx_mlir
@@ -0,0 +1,53 @@
#include "Patterns.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet owningPatterns(context);
for (auto* dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
populateConstantFoldingConstantPatterns(owningPatterns);
populateConstantFoldingSubviewPatterns(owningPatterns);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}
void runOnOperation() override {
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
signalPassFailure();
return;
}
dumpModule(getOperation(), "pim2_folded");
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
} // namespace onnx_mlir
@@ -0,0 +1,135 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
static int64_t getValueSizeInBytes(Value value) {
auto type = dyn_cast<ShapedType>(value.getType());
if (!type || !type.hasStaticShape())
return -1;
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct PimMaterializeConstantsPass
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
OpBuilder rewriter(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
if (isa<pim::PimHaltOp>(op))
continue;
for (OpOperand& operand : op.getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op.emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end()) {
operand.set(cachedValue->second);
continue;
}
int64_t totalBytes = getValueSizeInBytes(originalValue);
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
hasFailure = true;
continue;
}
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(&op);
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
op.getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(totalBytes)));
cachedByType[originalType] = hostToDevCopy.getResult();
operand.set(hostToDevCopy.getResult());
}
}
}
}
if (hasFailure) {
signalPassFailure();
return;
}
dumpModule(moduleOp, "pim3_materialized");
}
};
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
return std::make_unique<PimMaterializeConstantsPass>();
}
} // namespace onnx_mlir
+3 -1
View File
@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
@@ -26,49 +27,32 @@ static bool isAddressOnlyHostOp(Operation* op) {
spatial::SpatChannelNewOp>(op);
}
static bool isHostAddressableValue(Value value) {
while (true) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return true;
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
static bool isCodegenAddressableValue(Value value) {
auto resolvedAddress = resolveContiguousAddress(value);
if (failed(resolvedAddress))
return false;
}
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return false;
}
StringRef getArgument() const override { return "verify-pim-host-pass"; }
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override {
return "Verify that no runtime host-side code remains in bufferized PIM IR";
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
}
PimHostVerificationPass() {}
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
PimVerificationPass() {}
PimVerificationPass(const PimVerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
@@ -80,7 +64,7 @@ struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationP
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreOp)))
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
hasFailure = true;
continue;
}
@@ -139,14 +123,49 @@ private:
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
bool hasFailure = false;
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
if (!isHostAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
if (!isCodegenAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
}
return success(!hasFailure);
}
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
bool hasFailure = false;
for (Operation& op : coreOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op))
continue;
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
auto resolvedAddress = resolveContiguousAddress(operand);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
}
}
return success(!hasFailure);
}
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlySource(op, subviewOp.getSource());
@@ -160,16 +179,16 @@ private:
}
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
if (isHostAddressableValue(source))
if (isCodegenAddressableValue(source))
return success();
op->emitOpError("depends on a value that still requires host-side execution");
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
return failure();
}
};
} // namespace
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
} // namespace onnx_mlir
+2 -1
View File
@@ -74,7 +74,8 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimHostVerificationPass);
registerPass(createPimMaterializeConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass);
}
+59
View File
@@ -0,0 +1,59 @@
# Validation Operations
ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline.
Generated tests can be regenerated with:
```
python3 validation/operations/gen_tests.py
```
## Conv
| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes |
|------|-----------|-------|--------|--------|--------|---------|------|-------|
| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted |
| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias |
| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input |
| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel |
| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution |
| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels |
| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing |
| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved |
| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads |
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
## Pool
| Test | Directory | Input | Output | Kernel | Stride | Padding | Notes |
|------|-----------|-------|--------|--------|--------|---------|-------|
| Max basic | `pool/max_basic` | [1,1,4,4] | [1,1,3,3] | 2x2 | 1 | none | Basic max pooling |
| Max stride 2 multi-channel | `pool/max_stride2_multichannel` | [1,5,6,6] | [1,5,3,3] | 2x2 | 2 | none | Channel-preserving max pool |
| Max SAME_UPPER | `pool/max_same_upper` | [1,1,5,5] | [1,1,3,3] | 3x3 | 2 | SAME_UPPER | Deprecated auto_pad path |
| Avg basic | `pool/avg_basic` | [1,3,4,4] | [1,3,3,3] | 2x2 | 1 | none | Basic average pooling |
| Avg explicit padding | `pool/avg_explicit_padding` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=0` |
| Avg include pad | `pool/avg_include_pad` | [1,2,4,4] | [1,2,2,2] | 3x3 | 2 | [1,1,1,1] | `count_include_pad=1` |
| Max after Conv | `pool/max_after_conv` | [1,3,6,6] | [1,4,2,2] | Conv 3x3 then Pool 2x2 | 2 | none | Regression for `pool(conv(...))` |
## Gemm
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|------|-----------|-----------|------------|--------|--------|-------|------|------|-------|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias |
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
## Gemv
| Test | Directory | Input | W (weight) | Output | Bias | Notes |
|------|-----------|-------|------------|--------|------|-------|
| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul |
| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant |
| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape |
| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern |
| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+364
View File
@@ -0,0 +1,364 @@
#!/usr/bin/env python3
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
from pathlib import Path
OPERATIONS_DIR = Path(__file__).parent
def save_model(model, directory, filename):
"""Save an ONNX model, creating the directory if needed."""
d = OPERATIONS_DIR / directory
d.mkdir(parents=True, exist_ok=True)
path = d / filename
onnx.checker.check_model(model)
onnx.save(model, str(path))
print(f" {path.relative_to(OPERATIONS_DIR)}")
# ---------------------------------------------------------------------------
# GEMM tests
# ---------------------------------------------------------------------------
def gemm_non_square():
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
B, K, N = 4, 128, 64
W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/non_square", "gemm_non_square.onnx")
def gemm_with_bias():
"""GEMM with bias: Y = A @ W + C."""
B, K, N = 4, 128, 128
rng = np.random.default_rng(43)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"])
graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/with_bias", "gemm_with_bias.onnx")
def gemm_transB():
"""GEMM with transB=1: Y = A @ W^T."""
B, K, N = 4, 128, 64
rng = np.random.default_rng(44)
# W stored as [N, K], transposed during computation
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1)
graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/transB", "gemm_transB.onnx")
def gemm_alpha_beta():
"""GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C."""
B, K, N = 4, 64, 64
rng = np.random.default_rng(45)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25)
graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx")
def gemm_small():
"""Small GEMM: [2, 8] @ [8, 4]."""
B, K, N = 2, 8, 4
rng = np.random.default_rng(46)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/small", "gemm_small.onnx")
def gemm_large():
"""Larger GEMM: [8, 256] @ [256, 128]."""
B, K, N = 8, 256, 128
rng = np.random.default_rng(47)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/large", "gemm_large.onnx")
def gemm_transB_with_bias():
"""GEMM with transB and bias: Y = A @ W^T + C."""
B, K, N = 4, 128, 64
rng = np.random.default_rng(48)
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1)
graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
# ---------------------------------------------------------------------------
# Conv tests
# ---------------------------------------------------------------------------
def conv_3x3_kernel():
"""Conv with 3x3 kernel, no padding."""
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
W = numpy_helper.from_array(
np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx")
def conv_stride2():
"""Conv with 3x3 kernel and stride 2."""
# Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
W = numpy_helper.from_array(
np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/stride_2", "conv_stride_2.onnx")
def conv_multi_channel():
"""Conv with multiple input and output channels."""
# Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3])
W = numpy_helper.from_array(
np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/multi_channel", "conv_multi_channel.onnx")
def conv_1x1():
"""1x1 pointwise convolution (channel mixing)."""
# Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
W = numpy_helper.from_array(
np.random.default_rng(53).uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_1x1", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx")
def conv_same_padding_3x3():
"""Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions."""
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5])
W = numpy_helper.from_array(
np.random.default_rng(54).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "conv_same_3x3", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx")
def conv_explicit_padding():
"""Conv 3x3 with explicit asymmetric padding."""
# Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4])
W = numpy_helper.from_array(
np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1])
graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx")
def conv_with_bias_3x3():
"""Conv 3x3 with bias."""
# Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3]
rng = np.random.default_rng(56)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_with_bias_3x3", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx")
def conv_batch_2():
"""Batched conv (batch=2) with SAME_UPPER padding and bias."""
# Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3]
rng = np.random.default_rng(57)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (1,)).astype(np.float32), name="B")
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "conv_batch_2", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/batch_2", "conv_batch_2.onnx")
def conv_large_spatial():
"""Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6])
W = numpy_helper.from_array(
np.random.default_rng(58).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_large_spatial", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
# ---------------------------------------------------------------------------
# Pooling tests
# ---------------------------------------------------------------------------
def maxpool_basic():
"""MaxPool 2x2 with stride 1."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "maxpool_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_basic", "maxpool_basic.onnx")
def maxpool_stride2_multichannel():
"""MaxPool 2x2 with stride 2 on multiple channels."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 5, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 5, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "maxpool_stride2_multichannel", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_stride2_multichannel", "maxpool_stride2_multichannel.onnx")
def maxpool_same_upper():
"""MaxPool 3x3 with SAME_UPPER padding."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
node = helper.make_node("MaxPool", ["X"], ["Y"], kernel_shape=[3, 3], strides=[2, 2], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "maxpool_same_upper", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_same_upper", "maxpool_same_upper.onnx")
def avgpool_basic():
"""AveragePool 2x2 with stride 1."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 3, 3])
node = helper.make_node("AveragePool", ["X"], ["Y"], kernel_shape=[2, 2], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "avgpool_basic", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_basic", "avgpool_basic.onnx")
def avgpool_explicit_padding():
"""AveragePool 3x3 with explicit padding, excluding pad from the divisor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node("AveragePool", ["X"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=0)
graph = helper.make_graph([node], "avgpool_explicit_padding", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_explicit_padding", "avgpool_explicit_padding.onnx")
def avgpool_include_pad():
"""AveragePool 3x3 with explicit padding, including pad in the divisor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 2, 2])
node = helper.make_node("AveragePool", ["X"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[1, 1, 1, 1], count_include_pad=1)
graph = helper.make_graph([node], "avgpool_include_pad", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/avg_include_pad", "avgpool_include_pad.onnx")
def maxpool_after_conv():
"""Conv followed by MaxPool to validate pooling on lowered conv results."""
rng = np.random.default_rng(59)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 2, 2])
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
pool = helper.make_node("MaxPool", ["C"], ["Y"], kernel_shape=[2, 2], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([conv, pool], "maxpool_after_conv", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("Generating GEMM tests:")
gemm_non_square()
gemm_with_bias()
gemm_transB()
gemm_alpha_beta()
gemm_small()
gemm_large()
gemm_transB_with_bias()
print("\nGenerating Conv tests:")
conv_3x3_kernel()
conv_stride2()
conv_multi_channel()
conv_1x1()
conv_same_padding_3x3()
conv_explicit_padding()
conv_with_bias_3x3()
conv_batch_2()
conv_large_spatial()
print("\nGenerating Pooling tests:")
maxpool_basic()
maxpool_stride2_multichannel()
maxpool_same_upper()
avgpool_basic()
avgpool_explicit_padding()
avgpool_include_pad()
maxpool_after_conv()
print("\nDone.")