Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f3c7464b4 | |||
| c77ffa9c56 | |||
| 495186503c | |||
| 2c1da813b5 | |||
| 8337a11ce9 | |||
| d136136d22 | |||
| 074eb183c7 | |||
| 43ed3914b8 | |||
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 | |||
| a50e77ff38 | |||
| f56c4159b5 | |||
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 |
+80
-12
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
|||||||
|
|
||||||
project(raptor)
|
project(raptor)
|
||||||
|
|
||||||
# Add symlink to PIM as accelerator in onnx-mlir
|
# Materialize a CMake shim directory
|
||||||
function(raptor_ensure_symlink link_path target_path)
|
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||||
|
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||||
|
|
||||||
if(NOT EXISTS "${link_parent}")
|
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
message(FATAL_ERROR
|
||||||
|
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||||
|
" ${real_external_source_dir}"
|
||||||
|
)
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (IS_SYMLINK "${shim_dir}")
|
||||||
|
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||||
|
file(REMOVE "${shim_dir}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||||
|
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
file(MAKE_DIRECTORY "${shim_dir}")
|
||||||
|
|
||||||
|
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||||
|
set(shim_contents
|
||||||
|
"get_filename_component(raptor_external_source_dir
|
||||||
|
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||||
|
REALPATH
|
||||||
|
)
|
||||||
|
add_subdirectory(
|
||||||
|
\"\${raptor_external_source_dir}\"
|
||||||
|
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||||
|
)
|
||||||
|
if (DEFINED PIM_ENABLED)
|
||||||
|
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||||
|
endif ()
|
||||||
|
"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (EXISTS "${shim_file}")
|
||||||
|
file(READ "${shim_file}" old_contents)
|
||||||
|
else ()
|
||||||
|
set(old_contents "")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (NOT old_contents STREQUAL shim_contents)
|
||||||
|
file(WRITE "${shim_file}" "${shim_contents}")
|
||||||
|
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||||
|
else ()
|
||||||
|
message(STATUS "CMake shim already up to date for ${description}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
# Mirror the external tree's first-level entries into the shim directory
|
||||||
|
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||||
|
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||||
|
|
||||||
|
foreach (child IN LISTS children)
|
||||||
|
if (child STREQUAL "CMakeLists.txt")
|
||||||
|
continue()
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
set(real_child "${real_external_source_dir}/${child}")
|
||||||
|
set(shim_child "${shim_dir}/${child}")
|
||||||
|
|
||||||
|
if (IS_SYMLINK "${shim_child}")
|
||||||
|
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||||
|
if (existing_link_target STREQUAL real_child)
|
||||||
|
continue()
|
||||||
|
endif ()
|
||||||
|
file(REMOVE_RECURSE "${shim_child}")
|
||||||
|
elseif (EXISTS "${shim_child}")
|
||||||
|
# Do not delete real files/directories. This protects the generated shim.
|
||||||
|
continue()
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
if(NOT EXISTS "${link_path}")
|
|
||||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
|
||||||
file(CREATE_LINK
|
file(CREATE_LINK
|
||||||
"${target_path}"
|
"${real_child}"
|
||||||
"${link_path}"
|
"${shim_child}"
|
||||||
SYMBOLIC
|
SYMBOLIC
|
||||||
)
|
)
|
||||||
endif()
|
endforeach ()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
raptor_ensure_symlink(
|
raptor_write_external_cmake_shim(
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||||
|
"PIM accelerator"
|
||||||
)
|
)
|
||||||
raptor_ensure_symlink(
|
|
||||||
|
raptor_write_external_cmake_shim(
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||||
|
"PIM accelerator tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patch onnx-mlir sources for PIM accelerator support.
|
# Patch onnx-mlir sources for PIM accelerator support.
|
||||||
|
|||||||
@@ -145,6 +145,46 @@ validate.py \
|
|||||||
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Each validation run writes debugging artifacts into the benchmark's workspace
|
||||||
|
directory (for example `validation/operations/gemm/small/`):
|
||||||
|
- `inputs/` — generated input CSVs used for the run.
|
||||||
|
- `outputs/` — reference outputs dumped by the native ONNX runner.
|
||||||
|
- `raptor/` — compiler artifacts:
|
||||||
|
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
|
||||||
|
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
|
||||||
|
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
|
||||||
|
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
|
||||||
|
`raptor/reports/` such as `dcp_merge_report.txt`,
|
||||||
|
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
|
||||||
|
- `runner/` — generated reference runner source, build tree, and shared library.
|
||||||
|
- `simulation/out.bin` — raw simulator output dump used for output comparison.
|
||||||
|
|
||||||
|
That means you usually do not need to rerun standalone `--EmitSpatial` or
|
||||||
|
`--EmitPim` commands while debugging validation failures: the per-pass dialect
|
||||||
|
dumps are already available under `raptor/dialects/`.
|
||||||
|
|
||||||
|
The validator does not currently expose a simulator tracing flag, but once a
|
||||||
|
validation has produced `raptor/pim/` you can rerun the simulator manually with
|
||||||
|
tracing enabled:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend-simulators/pim/pim-simulator
|
||||||
|
cargo run --no-default-features --features tracing --release \
|
||||||
|
--package pim-simulator --bin pim-simulator -- \
|
||||||
|
-f /path/to/workspace/raptor/pim \
|
||||||
|
-o /path/to/workspace/simulation/out.bin \
|
||||||
|
-d <addr0>,<size0>,<addr1>,<size1>,...
|
||||||
|
```
|
||||||
|
|
||||||
|
With `--features tracing`, the simulator writes per-core traces as
|
||||||
|
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
|
||||||
|
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
|
||||||
|
and the model output shapes. If you need a clean slate before rerunning, use:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
validate.py --clean
|
||||||
|
```
|
||||||
|
|
||||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||||
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
||||||
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
||||||
|
|||||||
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
|
|||||||
if in_path.contains(&waiting_for) {
|
if in_path.contains(&waiting_for) {
|
||||||
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||||
let cycle = &path[cycle_start..];
|
let cycle = &path[cycle_start..];
|
||||||
|
let format_core = |core: &i32| (core - 1).to_string();
|
||||||
|
|
||||||
let cycle_str = cycle
|
let cycle_str = cycle
|
||||||
.iter()
|
.iter()
|
||||||
.map(|c| c.to_string())
|
.map(format_core)
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(" -> ");
|
.join(" -> ");
|
||||||
|
|
||||||
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
|
|||||||
.copied()
|
.copied()
|
||||||
.chain(std::iter::once(waiting_for))
|
.chain(std::iter::once(waiting_for))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
|
||||||
let states_msg = cycle
|
let states_msg = cycle
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|core| {
|
.filter_map(|core| {
|
||||||
states.get(core).map(|state| match state {
|
states.get(core).map(|state| match state {
|
||||||
CoreState::SendingTo(target, size) => {
|
CoreState::SendingTo(target, size) => {
|
||||||
format!("core {} send {}B -> {}", core, size, target)
|
format!("core {} send {}B -> {}", core - 1, size, target - 1)
|
||||||
}
|
}
|
||||||
CoreState::ReceivingFrom(source, size) => {
|
CoreState::ReceivingFrom(source, size) => {
|
||||||
format!("core {} recv {}B <- {}", core, size, source)
|
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
|
||||||
}
|
}
|
||||||
CoreState::Working => format!("core {} working", core),
|
CoreState::Working => format!("core {} working", core - 1),
|
||||||
CoreState::Halted => format!("core {} halted", core),
|
CoreState::Halted => format!("core {} halted", core - 1),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
|
|||||||
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
|||||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||||
|
|
||||||
|
set(PIM_GENERATED_PATH_SHIM_TARGET "")
|
||||||
|
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
|
||||||
|
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
|
||||||
|
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
|
||||||
|
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
|
||||||
|
|
||||||
|
function(add_pim_generated_path_shim relative_path)
|
||||||
|
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
|
||||||
|
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
|
||||||
|
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
|
||||||
|
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT "${shim_file}"
|
||||||
|
DEPENDS "${real_file}"
|
||||||
|
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
|
||||||
|
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
|
||||||
|
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
|
||||||
|
VERBATIM
|
||||||
|
)
|
||||||
|
|
||||||
|
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
|
||||||
|
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
file(GLOB_RECURSE pim_generated_path_scan_sources
|
||||||
|
CONFIGURE_DEPENDS
|
||||||
|
"${PIM_SRC_ROOT}/*.cpp"
|
||||||
|
"${PIM_SRC_ROOT}/*.hpp"
|
||||||
|
)
|
||||||
|
|
||||||
|
set(pim_generated_path_shims)
|
||||||
|
foreach (source_file IN LISTS pim_generated_path_scan_sources)
|
||||||
|
file(READ "${source_file}" source_contents)
|
||||||
|
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
|
||||||
|
|
||||||
|
foreach (inc_match IN LISTS source_inc_matches)
|
||||||
|
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
|
||||||
|
list(APPEND pim_generated_path_shims "${relative_inc_path}")
|
||||||
|
endforeach ()
|
||||||
|
endforeach ()
|
||||||
|
|
||||||
|
list(REMOVE_DUPLICATES pim_generated_path_shims)
|
||||||
|
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
|
||||||
|
add_pim_generated_path_shim("${relative_inc_path}")
|
||||||
|
endforeach ()
|
||||||
|
|
||||||
|
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
|
||||||
|
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
|
||||||
|
endif ()
|
||||||
|
|
||||||
set(PIM_PUBLIC_INCLUDE_DIRS
|
set(PIM_PUBLIC_INCLUDE_DIRS
|
||||||
${ONNX_MLIR_SRC_ROOT}/include
|
${ONNX_MLIR_SRC_ROOT}/include
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
|
|||||||
|
|
||||||
function(add_pim_library name)
|
function(add_pim_library name)
|
||||||
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
||||||
|
if (PIM_GENERATED_PATH_SHIM_TARGET)
|
||||||
|
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
|
||||||
|
endif ()
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
add_subdirectory(Dialect)
|
add_subdirectory(Dialect)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
add_pim_library(OMPimCommon
|
add_pim_library(OMPimCommon
|
||||||
IR/AddressAnalysis.cpp
|
IR/AddressAnalysis.cpp
|
||||||
|
IR/ConstantUtils.cpp
|
||||||
IR/CoreBlockUtils.cpp
|
IR/CoreBlockUtils.cpp
|
||||||
IR/EntryPointUtils.cpp
|
IR/EntryPointUtils.cpp
|
||||||
IR/ShapeUtils.cpp
|
IR/ShapeUtils.cpp
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||||
|
|
||||||
|
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||||
|
const StaticValueKnowledge* knowledge) {
|
||||||
|
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||||
|
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto elementType = denseAttr.getElementType();
|
||||||
|
if (!elementType.isIndex() && !elementType.isInteger())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> indices;
|
||||||
|
indices.reserve(loadOp.getIndices().size());
|
||||||
|
for (mlir::Value index : loadOp.getIndices()) {
|
||||||
|
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||||
|
if (failed(resolvedIndex))
|
||||||
|
return mlir::failure();
|
||||||
|
indices.push_back(*resolvedIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||||
|
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||||
|
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
value = resolveAlias(value, knowledge);
|
value = resolveAlias(value, knowledge);
|
||||||
@@ -126,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
|||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||||
|
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||||
|
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,7 +264,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
|
#include "ConstantUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Block* getHostConstantBlock(Operation* anchorOp) {
|
||||||
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
|
||||||
|
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||||
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||||
|
return current->getBlock();
|
||||||
|
|
||||||
|
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||||
|
return &funcOp.getBody().front();
|
||||||
|
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||||
|
return moduleOp.getBody();
|
||||||
|
return anchorOp->getBlock();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
|
||||||
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||||
|
for (Operation& op : *hostBlock) {
|
||||||
|
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||||
|
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||||
|
continue;
|
||||||
|
return constantOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||||
|
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||||
|
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||||
|
mlir::Attribute value,
|
||||||
|
mlir::Type type,
|
||||||
|
mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
@@ -30,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
|
|||||||
for (mlir::Operation& op : block) {
|
for (mlir::Operation& op : block) {
|
||||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||||
continue;
|
continue;
|
||||||
|
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||||
|
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||||
|
continue;
|
||||||
|
|
||||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||||
mlir::Block& loopBody = forOp.getRegion().front();
|
mlir::Block& loopBody = forOp.getRegion().front();
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
|
||||||
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
|||||||
return numElements;
|
return numElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasByteSizedElementType(mlir::Type elementType) {
|
||||||
|
if (mlir::isa<mlir::IndexType>(elementType))
|
||||||
|
return true;
|
||||||
|
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||||
|
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
|
||||||
|
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||||
|
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||||
|
if (mlir::isa<mlir::IndexType>(elementType))
|
||||||
|
return mlir::IndexType::kInternalStorageBitWidth / 8;
|
||||||
|
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||||
|
return static_cast<size_t>(intType.getWidth() / 8);
|
||||||
|
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||||
|
return static_cast<size_t>(floatType.getWidth() / 8);
|
||||||
|
llvm_unreachable("expected byte-sized integer, float, or index element type");
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
|
||||||
|
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
|
|||||||
|
|
||||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool hasByteSizedElementType(mlir::Type elementType);
|
||||||
|
|
||||||
|
size_t getElementTypeSizeInBytes(mlir::Type elementType);
|
||||||
|
|
||||||
|
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
|||||||
@@ -21,12 +21,15 @@ namespace {
|
|||||||
|
|
||||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
|
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||||
|
if (!weightArg)
|
||||||
|
return false;
|
||||||
bool found = false;
|
bool found = false;
|
||||||
parentOp.walk([&](mlir::Operation* op) {
|
parentOp.walk([&](mlir::Operation* op) {
|
||||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
found |= mvmOp.getWeight() == *weightArg;
|
||||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
found |= vmmOp.getWeight() == *weightArg;
|
||||||
});
|
});
|
||||||
return found;
|
return found;
|
||||||
}
|
}
|
||||||
@@ -35,13 +38,19 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
|||||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||||
auto weights = parentOp.getWeights();
|
auto weights = parentOp.getWeights();
|
||||||
llvm::SmallSet<unsigned, 8> visited;
|
llvm::SmallSet<unsigned, 8> visited;
|
||||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
auto walkWeight = [&](mlir::Value weight) {
|
||||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||||
|
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||||
|
if (!weightArg || *weightArg != weight)
|
||||||
|
continue;
|
||||||
|
if (visited.insert(weightIndex).second)
|
||||||
callback(parentOp->getOpOperand(weightIndex));
|
callback(parentOp->getOpOperand(weightIndex));
|
||||||
|
break;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -90,18 +99,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
|||||||
assert(root && "expected valid root op");
|
assert(root && "expected valid root op");
|
||||||
root->walk([&](pim::PimCoreOp coreOp) {
|
root->walk([&](pim::PimCoreOp coreOp) {
|
||||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||||
auto weights = coreOp.getWeights();
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||||
if (weightIndex < weights.size())
|
|
||||||
callback(coreOp->getOpOperand(weightIndex));
|
callback(coreOp->getOpOperand(weightIndex));
|
||||||
|
break;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
auto weights = coreBatchOp.getWeights();
|
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||||
for (auto weight : weights)
|
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||||
for (mlir::OpOperand& use : weight.getUses())
|
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||||
if (use.getOwner() == coreBatchOp.getOperation())
|
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||||
callback(use);
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
|||||||
@@ -13,7 +13,8 @@
|
|||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
struct CappedDiagnosticReporter {
|
struct CappedDiagnosticReporter {
|
||||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
|
||||||
|
: maxReportedFailures(maxReportedFailures) {}
|
||||||
|
|
||||||
template <typename EmitFn>
|
template <typename EmitFn>
|
||||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||||
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
|
|||||||
|
|
||||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||||
if (numFailures > maxReportedFailures)
|
if (numFailures > maxReportedFailures)
|
||||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||||
<< failureDescription;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasFailure() const { return numFailures != 0; }
|
bool hasFailure() const { return numFailures != 0; }
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
|
|
||||||
@@ -24,113 +28,166 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
|
|||||||
return laneCoreIds;
|
return laneCoreIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
|
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
|
||||||
IRRewriter rewriter(scalarCore.getContext());
|
if (Value mapped = mapper.lookupOrNull(value))
|
||||||
SmallVector<Operation*> batchOps;
|
return mapped;
|
||||||
scalarCore.walk([&](Operation* op) {
|
|
||||||
if (isa<pim::PimSendBatchOp,
|
|
||||||
pim::PimSendTensorBatchOp,
|
|
||||||
pim::PimReceiveBatchOp,
|
|
||||||
pim::PimReceiveTensorBatchOp,
|
|
||||||
pim::PimMemCopyHostToDevBatchOp>(op)) {
|
|
||||||
batchOps.push_back(op);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
for (Operation* op : batchOps) {
|
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
|
||||||
rewriter.setInsertionPoint(op);
|
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
|
||||||
|
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
assert(definingOp && "expected captured value to be defined by an operation");
|
||||||
|
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
|
||||||
|
|
||||||
|
for (Value operand : definingOp->getOperands())
|
||||||
|
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
|
||||||
|
|
||||||
|
Operation* cloned = builder.clone(*definingOp, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
return mapper.lookup(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||||
|
pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
OperationFolder& constantFolder) {
|
||||||
|
Block& oldBlock = coreBatchOp.getBody().front();
|
||||||
|
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||||
|
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||||
|
size_t weightCount = coreBatchOp.getWeights().size();
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||||
|
if (blockArg.getType().isIndex()) {
|
||||||
|
mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argIndex <= weightCount) {
|
||||||
|
auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
|
||||||
|
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t inputIndex = argIndex - 1 - weightCount;
|
||||||
|
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
|
||||||
|
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation& op : oldBlock) {
|
||||||
|
if (isa<pim::PimHaltOp>(op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (Value operand : op.getOperands())
|
||||||
|
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
|
||||||
|
|
||||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||||
pim::PimSendOp::create(rewriter,
|
pim::PimSendOp::create(
|
||||||
|
builder,
|
||||||
sendBatchOp.getLoc(),
|
sendBatchOp.getLoc(),
|
||||||
sendBatchOp.getInput(),
|
mapper.lookup(sendBatchOp.getInput()),
|
||||||
sendBatchOp.getSizeAttr(),
|
sendBatchOp.getSizeAttr(),
|
||||||
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
|
||||||
rewriter.eraseOp(op);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||||
pim::PimSendTensorOp::create(
|
pim::PimSendTensorOp::create(
|
||||||
rewriter,
|
builder,
|
||||||
sendTensorBatchOp.getLoc(),
|
sendTensorBatchOp.getLoc(),
|
||||||
sendTensorBatchOp.getInput(),
|
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||||
rewriter.eraseOp(op);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||||
auto scalarReceive =
|
auto scalarReceive = pim::PimReceiveOp::create(
|
||||||
pim::PimReceiveOp::create(rewriter,
|
builder,
|
||||||
receiveBatchOp.getLoc(),
|
receiveBatchOp.getLoc(),
|
||||||
receiveBatchOp.getOutput().getType(),
|
receiveBatchOp.getOutput().getType(),
|
||||||
receiveBatchOp.getOutputBuffer(),
|
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||||
receiveBatchOp.getSizeAttr(),
|
receiveBatchOp.getSizeAttr(),
|
||||||
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
|
||||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||||
rewriter,
|
builder,
|
||||||
receiveTensorBatchOp.getLoc(),
|
receiveTensorBatchOp.getLoc(),
|
||||||
receiveTensorBatchOp.getOutput().getType(),
|
receiveTensorBatchOp.getOutput().getType(),
|
||||||
receiveTensorBatchOp.getOutputBuffer(),
|
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
|
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
|
||||||
|
builder,
|
||||||
memcpBatchOp.getLoc(),
|
memcpBatchOp.getLoc(),
|
||||||
memcpBatchOp.getOutput().getType(),
|
memcpBatchOp.getOutput().getType(),
|
||||||
memcpBatchOp.getDeviceTarget(),
|
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||||
memcpBatchOp.getHostSource(),
|
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
mapper.lookup(memcpBatchOp.getHostSource()),
|
||||||
memcpBatchOp.getSizeAttr());
|
memcpBatchOp.getSizeAttr());
|
||||||
rewriter.replaceOp(op, scalarCopy->getResults());
|
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = builder.clone(op, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||||
unsigned lane,
|
ArrayRef<unsigned> lanes,
|
||||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||||
|
assert(!lanes.empty() && "expected at least one batch lane");
|
||||||
|
|
||||||
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||||
OpBuilder builder(scratchModule->getContext());
|
OpBuilder builder(scratchModule->getContext());
|
||||||
|
OperationFolder constantFolder(scratchModule->getContext());
|
||||||
builder.setInsertionPointToStart(scratchModule->getBody());
|
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||||
|
|
||||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
|
||||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
|
||||||
SmallVector<Value> laneWeights;
|
|
||||||
laneWeights.reserve(weightsPerLane);
|
|
||||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
|
||||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
|
||||||
|
|
||||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||||
auto scalarCore = pim::PimCoreOp::create(
|
int32_t coreId = coreIds[lanes.front()];
|
||||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
for (unsigned lane : lanes)
|
||||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
|
||||||
IRMapping mapper;
|
|
||||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
|
||||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
|
||||||
|
|
||||||
builder.setInsertionPointToEnd(block);
|
auto scalarCore =
|
||||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
|
||||||
Operation* cloned = builder.clone(op, mapper);
|
SmallVector<Type> weightTypes;
|
||||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
SmallVector<Location> weightLocs;
|
||||||
mapper.map(originalResult, clonedResult);
|
weightTypes.reserve(weights.size());
|
||||||
|
weightLocs.reserve(weights.size());
|
||||||
|
for (Value weight : weights) {
|
||||||
|
weightTypes.push_back(weight.getType());
|
||||||
|
weightLocs.push_back(weight.getLoc());
|
||||||
}
|
}
|
||||||
|
Block* block =
|
||||||
|
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
|
||||||
|
builder.setInsertionPointToEnd(block);
|
||||||
|
for (unsigned lane : lanes)
|
||||||
|
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
|
||||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||||
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
|
|
||||||
return callback(scalarCore);
|
return callback(scalarCore);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
unsigned lane,
|
||||||
|
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||||
|
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -9,5 +9,8 @@ namespace onnx_mlir {
|
|||||||
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||||
unsigned lane,
|
unsigned lane,
|
||||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||||
|
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||||
|
llvm::ArrayRef<unsigned> lanes,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -41,15 +41,10 @@ using namespace mlir;
|
|||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
using namespace onnx_mlir::compact_asm;
|
using namespace onnx_mlir::compact_asm;
|
||||||
|
|
||||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
|
||||||
auto type = cast<ShapedType>(value.getType());
|
|
||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||||
auto type = cast<ShapedType>(value.getType());
|
auto type = cast<ShapedType>(value.getType());
|
||||||
assert("Only static shape is supported" && type.hasStaticShape());
|
assert("Only static shape is supported" && type.hasStaticShape());
|
||||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t allocSize = getShapedTypeSizeInBytes(type);
|
||||||
MemEntry memEntry = {0, allocSize};
|
MemEntry memEntry = {0, allocSize};
|
||||||
return &memEntries.emplace_back(memEntry, value).first;
|
return &memEntries.emplace_back(memEntry, value).first;
|
||||||
}
|
}
|
||||||
@@ -398,20 +393,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||||
|
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
|
||||||
|
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
|
||||||
|
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
|
||||||
|
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
|
||||||
emitMemCopyOp("ld",
|
emitMemCopyOp("ld",
|
||||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||||
loadOp.getDeviceTargetOffset(),
|
*deviceTargetOffset,
|
||||||
addressOf(loadOp.getHostSource(), knowledge),
|
addressOf(loadOp.getHostSource(), knowledge),
|
||||||
loadOp.getHostSourceOffset(),
|
*hostSourceOffset,
|
||||||
loadOp.getSize());
|
loadOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||||
|
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
|
||||||
|
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
|
||||||
|
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
|
||||||
|
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
|
||||||
emitMemCopyOp("st",
|
emitMemCopyOp("st",
|
||||||
addressOf(storeOp.getHostTarget(), knowledge),
|
addressOf(storeOp.getHostTarget(), knowledge),
|
||||||
storeOp.getHostTargetOffset(),
|
*hostTargetOffset,
|
||||||
addressOf(storeOp.getDeviceSource(), knowledge),
|
addressOf(storeOp.getDeviceSource(), knowledge),
|
||||||
storeOp.getDeviceSourceOffset(),
|
*deviceSourceOffset,
|
||||||
storeOp.getSize());
|
storeOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -426,25 +429,30 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
||||||
emitCommunicationOp(
|
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
|
||||||
|
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||||
const StaticValueKnowledge& knowledge) const {
|
const StaticValueKnowledge& knowledge) const {
|
||||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||||
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
|
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
|
||||||
|
/ receiveTensorOp.getSourceCoreIds().size();
|
||||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||||
|
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||||
|
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||||
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
|
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
|
||||||
|
/ sendTensorOp.getTargetCoreIds().size();
|
||||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||||
}
|
}
|
||||||
@@ -455,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
|
|||||||
|
|
||||||
int64_t axis = concatOp.getAxis();
|
int64_t axis = concatOp.getAxis();
|
||||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
|
||||||
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
||||||
|
|
||||||
size_t outerCount = 1;
|
size_t outerCount = 1;
|
||||||
@@ -507,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -522,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -537,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -552,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -567,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
|||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 2;
|
instruction.r2OrImm = 2;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -582,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
|||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.r2OrImm = 1;
|
instruction.r2OrImm = 1;
|
||||||
instruction.generic1 = 1;
|
instruction.generic1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -595,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
|||||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -608,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
|||||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -621,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
|||||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput()));
|
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -634,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
|||||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||||
instruction.rd = 0;
|
instruction.rd = 0;
|
||||||
instruction.r1 = 1;
|
instruction.r1 = 1;
|
||||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput()));
|
instruction.generic3 =
|
||||||
|
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -647,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
|
|||||||
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||||
auto srcShape = srcType.getShape();
|
auto srcShape = srcType.getShape();
|
||||||
size_t rank = srcShape.size();
|
size_t rank = srcShape.size();
|
||||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
|
||||||
size_t totalElements = srcType.getNumElements();
|
size_t totalElements = srcType.getNumElements();
|
||||||
|
|
||||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||||
@@ -728,12 +737,19 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
|
|
||||||
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||||
SmallVector<unsigned, 8> indices;
|
SmallVector<unsigned, 8> indices;
|
||||||
auto addIndex = [&](unsigned weightIndex) {
|
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||||
|
auto addWeight = [&](mlir::Value weight) {
|
||||||
|
if (!coreOp)
|
||||||
|
return;
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||||
|
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||||
|
continue;
|
||||||
if (!llvm::is_contained(indices, weightIndex))
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
indices.push_back(weightIndex);
|
indices.push_back(weightIndex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
|
||||||
llvm::sort(indices);
|
llvm::sort(indices);
|
||||||
return indices;
|
return indices;
|
||||||
}
|
}
|
||||||
@@ -795,6 +811,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
|||||||
/// fully resolved before the JSON instructions are emitted.
|
/// fully resolved before the JSON instructions are emitted.
|
||||||
/// Returns the number of emitted instructions, or -1 on failure.
|
/// Returns the number of emitted instructions, or -1 on failure.
|
||||||
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||||
|
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
|
||||||
|
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
|
||||||
|
if (!coreOp)
|
||||||
|
return std::nullopt;
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||||
|
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||||
|
return weightIndex;
|
||||||
|
return std::nullopt;
|
||||||
|
};
|
||||||
size_t processedOperations = 0;
|
size_t processedOperations = 0;
|
||||||
auto result =
|
auto result =
|
||||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
@@ -814,8 +839,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
auto weightIndex = resolveWeightIndex(vmmOp);
|
||||||
|
if (!weightIndex)
|
||||||
|
return failure();
|
||||||
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
|
||||||
|
}
|
||||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||||
@@ -1004,10 +1033,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
reportedCoreIds.reserve(batchCoreIds.size());
|
reportedCoreIds.reserve(batchCoreIds.size());
|
||||||
MemoryReportRow batchRow;
|
MemoryReportRow batchRow;
|
||||||
std::optional<MemoryReportRow> batchPerCoreRow;
|
std::optional<MemoryReportRow> batchPerCoreRow;
|
||||||
|
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
|
||||||
|
SmallVector<size_t> orderedOriginalCoreIds;
|
||||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
|
||||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
|
||||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||||
|
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
|
||||||
|
if (inserted)
|
||||||
|
orderedOriginalCoreIds.push_back(originalCoreId);
|
||||||
|
it->second.push_back(lane);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
||||||
|
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||||
|
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
|
||||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||||
MemoryReportRow laneRow;
|
MemoryReportRow laneRow;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
|
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerOptions"
|
#define DEBUG_TYPE "PimCompilerOptions"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -15,8 +15,8 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
|||||||
llvm::cl::init(EmitPimCodegen),
|
llvm::cl::init(EmitPimCodegen),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
llvm::cl::opt<PimMergeSchedulerType>
|
||||||
"pim-merge-scheduler",
|
pimMergeScheduler("pim-merge-scheduler",
|
||||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||||
|
|||||||
@@ -128,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
|
|
||||||
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||||
SmallVector<unsigned, 8> indices;
|
SmallVector<unsigned, 8> indices;
|
||||||
auto addIndex = [&](unsigned weightIndex) {
|
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||||
|
auto addWeight = [&](mlir::Value weight) {
|
||||||
|
if (!coreOp)
|
||||||
|
return;
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||||
|
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||||
|
continue;
|
||||||
if (!llvm::is_contained(indices, weightIndex))
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
indices.push_back(weightIndex);
|
indices.push_back(weightIndex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||||
llvm::sort(indices);
|
llvm::sort(indices);
|
||||||
return indices;
|
return indices;
|
||||||
}
|
}
|
||||||
@@ -200,7 +208,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
int64_t numCols = shape[1];
|
int64_t numCols = shape[1];
|
||||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||||
|
|
||||||
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||||
|
|
||||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||||
|
|||||||
@@ -18,13 +18,17 @@ namespace detail {
|
|||||||
|
|
||||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||||
|
|
||||||
|
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||||
|
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
template <typename Fn, size_t... Is>
|
||||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Fn, size_t... Is>
|
template <typename Fn, size_t... Is>
|
||||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||||
return std::forward<Fn>(fn)(values[Is]...);
|
return std::forward<Fn>(fn)(values[Is]...);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value weight : weights)
|
||||||
|
block->addArgument(weight.getType(), loc);
|
||||||
for (mlir::Value input : inputs)
|
for (mlir::Value input : inputs)
|
||||||
block->addArgument(input.getType(), loc);
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
@@ -93,14 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
|
|
||||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||||
|
detail::getInputBlockArgs(block, weights.size()),
|
||||||
|
std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return computeOp;
|
return computeOp;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto bodyResult =
|
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
detail::getInputBlockArgs(block, weights.size()),
|
||||||
|
std::make_index_sequence<NumInputs> {});
|
||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
@@ -123,6 +132,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value weight : weights)
|
||||||
|
block->addArgument(weight.getType(), loc);
|
||||||
for (mlir::Value input : inputs)
|
for (mlir::Value input : inputs)
|
||||||
block->addArgument(input.getType(), loc);
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
@@ -131,13 +142,13 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
|
|
||||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return computeOp;
|
return computeOp;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
|
|||||||
@@ -44,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||||
if (!computes.empty())
|
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||||
|
if (!computes.empty() || !computeBatches.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||||
@@ -190,16 +191,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
|
||||||
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
|
|
||||||
|
|
||||||
RewritePatternSet earlyPostPatterns(ctx);
|
|
||||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
|
||||||
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
|
|
||||||
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
PassManager cleanupPM(ctx);
|
PassManager cleanupPM(ctx);
|
||||||
cleanupPM.addPass(createCanonicalizerPass());
|
cleanupPM.addPass(createCanonicalizerPass());
|
||||||
|
|||||||
@@ -402,13 +402,33 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||||
|
|
||||||
auto computeOp = createSpatCompute(
|
auto computeOp =
|
||||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
|
||||||
|
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
|
||||||
|
for (Value weight : weights) {
|
||||||
|
blockArgTypes.push_back(weight.getType());
|
||||||
|
blockArgLocs.push_back(gemmLoc);
|
||||||
|
}
|
||||||
|
for (Value input : aHSlices[coreId]) {
|
||||||
|
blockArgTypes.push_back(input.getType());
|
||||||
|
blockArgLocs.push_back(gemmLoc);
|
||||||
|
}
|
||||||
|
Block* body =
|
||||||
|
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
SmallVector<Value> vmmOutputs;
|
SmallVector<Value> vmmOutputs;
|
||||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
|
||||||
vmmOutputs.push_back(
|
auto weightArg = computeOp.getWeightArgument(aHSliceId);
|
||||||
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
||||||
|
if (!weightArg || !inputArg)
|
||||||
|
return failure();
|
||||||
|
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
|
||||||
|
}
|
||||||
if (vmmOutputs.empty()) {
|
if (vmmOutputs.empty()) {
|
||||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
return failure();
|
return failure();
|
||||||
@@ -416,10 +436,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||||
return success();
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
});
|
|
||||||
if (failed(computeOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
partialResults.push_back(computeOp->getResult(0));
|
partialResults.push_back(computeOp->getResult(0));
|
||||||
}
|
}
|
||||||
@@ -530,37 +547,49 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
sharedBias = c;
|
sharedBias = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
|
||||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
|
||||||
|
|
||||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
|
||||||
|
|
||||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
TypeRange(resultTypes),
|
TypeRange {outType},
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||||
ValueRange(weights),
|
ValueRange {b},
|
||||||
ValueRange(aSlices));
|
ValueRange {a});
|
||||||
|
|
||||||
Block* body = rewriter.createBlock(
|
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
|
||||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
SmallVector<Location> blockArgLocs(4, loc);
|
||||||
|
Block* body =
|
||||||
|
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
rewriter.setInsertionPointToEnd(body);
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
auto lane = batchOp.getLaneArgument();
|
||||||
|
auto weight = batchOp.getWeightArgument(0);
|
||||||
|
auto packedInput = batchOp.getInputArgument(0);
|
||||||
|
auto packedOutput = batchOp.getOutputArgument(0);
|
||||||
|
if (!lane || !weight || !packedInput || !packedOutput)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value row =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
|
||||||
|
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
|
||||||
Value laneResult = vmmResult;
|
Value laneResult = vmmResult;
|
||||||
if (sharedBias)
|
if (sharedBias)
|
||||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
|
||||||
|
|
||||||
|
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||||
|
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||||
|
tensor::ParallelInsertSliceOp::create(
|
||||||
|
rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
|
||||||
});
|
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
|
|||||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value collapseBatchDims(Value value,
|
static Value
|
||||||
int64_t batchSize,
|
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||||
int64_t rows,
|
|
||||||
int64_t cols,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
if (type.getRank() == 2 || type.getRank() == 3)
|
if (type.getRank() == 2 || type.getRank() == 3)
|
||||||
return value;
|
return value;
|
||||||
|
|
||||||
auto collapsedType =
|
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||||
SmallVector<ReassociationIndices> reassociation = {
|
|
||||||
ReassociationIndices {},
|
|
||||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
|
||||||
};
|
|
||||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||||
reassociation.front().push_back(dim);
|
reassociation.front().push_back(dim);
|
||||||
|
|
||||||
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
|
|||||||
return collapseCompute.getResult(0);
|
return collapseCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value expandBatchDims(Value value,
|
static Value
|
||||||
RankedTensorType outputType,
|
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
|
||||||
size_t batchRank,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||||
return value;
|
return value;
|
||||||
|
|
||||||
SmallVector<ReassociationIndices> reassociation = {
|
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||||
ReassociationIndices {},
|
|
||||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
|
||||||
};
|
|
||||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
|
||||||
|
|||||||
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
|
|||||||
|
|
||||||
Value outputC = channelLoop.getInductionVar();
|
Value outputC = channelLoop.getInductionVar();
|
||||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||||
Value inputC =
|
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
|
||||||
|
|
||||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||||
|
|
||||||
Value outputH = heightLoop.getInductionVar();
|
Value outputH = heightLoop.getInductionVar();
|
||||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||||
Value inputH =
|
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
|
||||||
|
|
||||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||||
|
|
||||||
Value outputW = widthLoop.getInductionVar();
|
Value outputW = widthLoop.getInductionVar();
|
||||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||||
Value inputW =
|
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||||
Value inputSlice =
|
Value inputSlice =
|
||||||
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
|||||||
|
|
||||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||||
|| resizeOp.getNearestMode() != "floor")
|
|| resizeOp.getNearestMode() != "floor")
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(resizeOp,
|
||||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
"resize lowering currently supports only nearest + asymmetric + floor.");
|
||||||
|
|
||||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||||
|
|||||||
@@ -27,66 +27,26 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
|||||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
||||||
|
return arg && canPromoteInputBlockArgument(*arg);
|
||||||
|
}
|
||||||
|
|
||||||
static bool isDirectConstantValue(Value value) {
|
static bool isDirectConstantValue(Value value) {
|
||||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ComputeOpTy>
|
template <typename ComputeOpTy>
|
||||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||||
Block& block = compute.getBody().front();
|
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
if (inputIdx >= block.getNumArguments())
|
|
||||||
continue;
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
if (!isWeightLikeComputeOperand(input))
|
||||||
continue;
|
continue;
|
||||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||||
continue;
|
continue;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
|
||||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
|
||||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
|
||||||
if (batchOp.getLaneCount() != 1)
|
|
||||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
|
||||||
|
|
||||||
auto loc = batchOp.getLoc();
|
|
||||||
rewriter.setInsertionPoint(batchOp);
|
|
||||||
auto computeOp =
|
|
||||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
|
||||||
computeOp.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
|
||||||
|
|
||||||
Block& templateBlock = batchOp.getBody().front();
|
|
||||||
SmallVector<Type> blockArgTypes;
|
|
||||||
SmallVector<Location> blockArgLocs;
|
|
||||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
|
||||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
|
||||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
|
||||||
blockArgTypes.push_back(arg.getType());
|
|
||||||
blockArgLocs.push_back(loc);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto* newBlock =
|
|
||||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
||||||
IRMapping mapper;
|
|
||||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
|
||||||
mapper.map(oldArg, newArg);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
|
||||||
for (Operation& op : templateBlock)
|
|
||||||
rewriter.clone(op, mapper);
|
|
||||||
|
|
||||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
|
||||||
rewriter.eraseOp(batchOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||||
@@ -96,11 +56,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
bool needsRewrite = false;
|
bool needsRewrite = false;
|
||||||
Block& oldBlock = compute.getBody().front();
|
Block& oldBlock = compute.getBody().front();
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
if (inputIdx >= oldBlock.getNumArguments())
|
|
||||||
continue;
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
if (!isWeightLikeComputeOperand(input))
|
||||||
continue;
|
continue;
|
||||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||||
continue;
|
continue;
|
||||||
promoteInput[inputIdx] = true;
|
promoteInput[inputIdx] = true;
|
||||||
needsRewrite = true;
|
needsRewrite = true;
|
||||||
@@ -131,8 +89,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
|
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||||
auto* newBlock =
|
SmallVector<Type> newBlockArgTypes;
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
SmallVector<Location> newBlockArgLocs;
|
||||||
|
for (Value weight : newWeights) {
|
||||||
|
newBlockArgTypes.push_back(weight.getType());
|
||||||
|
newBlockArgLocs.push_back(weight.getLoc());
|
||||||
|
}
|
||||||
|
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||||
|
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||||
|
auto* newBlock = rewriter.createBlock(
|
||||||
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
@@ -141,17 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||||
|
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||||
|
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||||
|
if (!oldWeightArg || !newWeightArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||||
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
|
}
|
||||||
size_t newInputIdx = 0;
|
size_t newInputIdx = 0;
|
||||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||||
|
if (!oldArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
||||||
if (!promoteInput[oldInputIdx]) {
|
if (!promoteInput[oldInputIdx]) {
|
||||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||||
|
if (!newInputArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
||||||
|
mapper.map(*oldArg, *newInputArg);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||||
if (failed(clonedValue))
|
if (failed(clonedValue))
|
||||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||||
mapper.map(oldArg, *clonedValue);
|
mapper.map(*oldArg, *clonedValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Operation& op : oldBlock.without_terminator())
|
for (Operation& op : oldBlock.without_terminator())
|
||||||
@@ -180,11 +159,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
bool needsRewrite = false;
|
bool needsRewrite = false;
|
||||||
Block& oldBlock = compute.getBody().front();
|
Block& oldBlock = compute.getBody().front();
|
||||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
if (inputIdx >= oldBlock.getNumArguments())
|
|
||||||
continue;
|
|
||||||
if (!isWeightLikeComputeOperand(input))
|
if (!isWeightLikeComputeOperand(input))
|
||||||
continue;
|
continue;
|
||||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||||
continue;
|
continue;
|
||||||
promoteInput[inputIdx] = true;
|
promoteInput[inputIdx] = true;
|
||||||
needsRewrite = true;
|
needsRewrite = true;
|
||||||
@@ -220,8 +197,31 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||||
newWeights,
|
newWeights,
|
||||||
newInputs);
|
newInputs);
|
||||||
auto* newBlock =
|
auto laneArg = compute.getLaneArgument();
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
if (!laneArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||||
|
SmallVector<Type> newBlockArgTypes;
|
||||||
|
SmallVector<Location> newBlockArgLocs;
|
||||||
|
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||||
|
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||||
|
newBlockArgTypes.push_back(laneArg->getType());
|
||||||
|
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||||
|
for (Value weight : newWeights) {
|
||||||
|
newBlockArgTypes.push_back(weight.getType());
|
||||||
|
newBlockArgLocs.push_back(weight.getLoc());
|
||||||
|
}
|
||||||
|
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||||
|
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||||
|
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||||
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||||
|
if (!outputArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
||||||
|
newBlockArgTypes.push_back(resultType);
|
||||||
|
newBlockArgLocs.push_back(outputArg->getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* newBlock = rewriter.createBlock(
|
||||||
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
@@ -230,31 +230,45 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||||
|
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
|
auto newLaneArg = newCompute.getLaneArgument();
|
||||||
|
if (!newLaneArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
||||||
|
mapper.map(*laneArg, *newLaneArg);
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||||
|
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||||
|
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||||
|
if (!oldWeightArg || !newWeightArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||||
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
|
}
|
||||||
size_t newInputIdx = 0;
|
size_t newInputIdx = 0;
|
||||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||||
|
if (!oldArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
||||||
if (!promoteInput[oldInputIdx]) {
|
if (!promoteInput[oldInputIdx]) {
|
||||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||||
|
if (!newInputArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
||||||
|
mapper.map(*oldArg, *newInputArg);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||||
if (failed(clonedValue))
|
if (failed(clonedValue))
|
||||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||||
mapper.map(oldArg, *clonedValue);
|
mapper.map(*oldArg, *clonedValue);
|
||||||
|
}
|
||||||
|
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||||
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||||
|
if (!outputArg)
|
||||||
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||||
|
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Operation& op : oldBlock.without_terminator())
|
for (Operation& op : oldBlock)
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
|
||||||
SmallVector<Value> newYieldOperands;
|
|
||||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
|
||||||
for (Value operand : oldYield.getOutputs()) {
|
|
||||||
auto mapped = mapper.lookupOrNull(operand);
|
|
||||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
|
||||||
}
|
|
||||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
|
||||||
|
|
||||||
rewriter.replaceOp(compute, newCompute.getResults());
|
rewriter.replaceOp(compute, newCompute.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -262,10 +276,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||||
}
|
}
|
||||||
@@ -277,8 +287,6 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
|
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|||||||
@@ -7,14 +7,10 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
|
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||||
|
|
||||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
@@ -17,6 +20,37 @@ namespace {
|
|||||||
|
|
||||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||||
|
|
||||||
|
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||||
|
SmallVector<int32_t> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (Value value : values) {
|
||||||
|
FailureOr<int32_t> constantValue = getConstantI32Value(value);
|
||||||
|
if (failed(constantValue))
|
||||||
|
return failure();
|
||||||
|
constants.push_back(*constantValue);
|
||||||
|
}
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
|
return operandIndex == 2;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||||
|
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
|
||||||
|
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
@@ -28,27 +62,30 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
|||||||
return coreIds;
|
return coreIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||||
IRMapping& mapper,
|
IRMapping& mapper,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
SmallVector<int32_t> targetCoreIds;
|
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
|
||||||
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
|
if (failed(targetCoreIds))
|
||||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
|
||||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
for (int32_t& targetCoreId : *targetCoreIds)
|
||||||
|
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||||
|
|
||||||
pim::PimSendTensorBatchOp::create(rewriter,
|
pim::PimSendTensorBatchOp::create(rewriter,
|
||||||
sendTensorBatchOp.getLoc(),
|
sendTensorBatchOp.getLoc(),
|
||||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||||
IRMapping& mapper,
|
IRMapping& mapper,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
|
||||||
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
|
if (failed(sourceCoreIds))
|
||||||
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
|
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||||
|
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||||
|
|
||||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||||
@@ -56,26 +93,81 @@ static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatc
|
|||||||
receiveTensorBatchOp.getLoc(),
|
receiveTensorBatchOp.getLoc(),
|
||||||
outputBuffer.getType(),
|
outputBuffer.getType(),
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||||
|
if (!result.hasOneUse())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
|
||||||
|
if (!returnOp)
|
||||||
|
return failure();
|
||||||
|
return result.getUses().begin()->getOperandNumber();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||||
|
if (scale == 1)
|
||||||
|
return base;
|
||||||
|
|
||||||
|
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
|
||||||
|
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||||
|
tensor::ParallelInsertSliceOp insertSlice,
|
||||||
|
ShapedType destinationType,
|
||||||
|
IRMapping& mapper) {
|
||||||
|
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||||
|
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
||||||
|
ArrayRef<int64_t> shape = destinationType.getShape();
|
||||||
|
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
|
||||||
|
Value totalOffset;
|
||||||
|
Location loc = insertSlice.getLoc();
|
||||||
|
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
||||||
|
int64_t scale = strides[dim] * elementBytes;
|
||||||
|
Value scaledOffset;
|
||||||
|
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||||
|
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
|
assert(intAttr && "expected integer offset attribute");
|
||||||
|
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
totalOffset =
|
||||||
|
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!totalOffset)
|
||||||
|
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
|
return totalOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
if (computeBatchOp.getNumResults() != 0)
|
|
||||||
return computeBatchOp.emitOpError(
|
|
||||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
|
|
||||||
|
|
||||||
Location loc = computeBatchOp.getLoc();
|
Location loc = computeBatchOp.getLoc();
|
||||||
Block& oldBlock = computeBatchOp.getBody().front();
|
Block& oldBlock = computeBatchOp.getBody().front();
|
||||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||||
if (oldYield.getNumOperands() != 0)
|
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
|
||||||
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
|
if (computeBatchOp.getNumResults() == 0) {
|
||||||
|
if (!oldYield || oldYield.getNumOperands() != 0)
|
||||||
|
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
|
||||||
|
}
|
||||||
|
else if (!inParallelOp) {
|
||||||
|
return computeBatchOp.emitOpError(
|
||||||
|
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||||
SmallVector<Value> batchInputs;
|
SmallVector<Value> batchInputs;
|
||||||
if (!computeBatchOp.getInputs().empty())
|
if (!computeBatchOp.getInputs().empty())
|
||||||
@@ -91,9 +183,22 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
|
SmallVector<unsigned> returnOperandIndices;
|
||||||
|
if (computeBatchOp.getNumResults() != 0) {
|
||||||
|
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
||||||
|
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||||
|
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||||
|
if (failed(returnOperandIndex))
|
||||||
|
return computeBatchOp.emitOpError(
|
||||||
|
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||||
|
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes;
|
SmallVector<Type> blockArgTypes;
|
||||||
SmallVector<Location> blockArgLocs;
|
SmallVector<Location> blockArgLocs;
|
||||||
for (BlockArgument arg : oldBlock.getArguments()) {
|
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
|
||||||
|
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
|
||||||
blockArgTypes.push_back(arg.getType());
|
blockArgTypes.push_back(arg.getType());
|
||||||
blockArgLocs.push_back(arg.getLoc());
|
blockArgLocs.push_back(arg.getLoc());
|
||||||
}
|
}
|
||||||
@@ -102,7 +207,21 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
|
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
|
auto oldLaneArg = computeBatchOp.getLaneArgument();
|
||||||
|
if (!oldLaneArg)
|
||||||
|
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
|
||||||
|
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
|
||||||
|
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
|
||||||
|
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
|
||||||
|
if (!oldWeightArg)
|
||||||
|
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
|
||||||
|
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
|
||||||
|
}
|
||||||
|
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
|
||||||
|
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||||
|
if (!oldArg)
|
||||||
|
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
|
||||||
|
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
@@ -114,7 +233,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
rewriter.getI32IntegerAttr(0),
|
rewriter.getI32IntegerAttr(0),
|
||||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
mapper.map(oldArg, copied);
|
mapper.map(*oldArg, copied);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||||
@@ -136,26 +255,81 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
return copied;
|
return copied;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
||||||
|
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
||||||
|
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
||||||
|
if (hostOutputTensor)
|
||||||
|
return hostOutputTensor;
|
||||||
|
|
||||||
|
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
|
||||||
|
return hostOutputTensor;
|
||||||
|
};
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
for (Operation& op : oldBlock) {
|
for (Operation& op : oldBlock) {
|
||||||
if (isa<spatial::SpatYieldOp>(op))
|
if (isa<spatial::SpatYieldOp>(op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||||
|
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||||
|
if (!firstOutputArg)
|
||||||
|
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
|
||||||
|
for (Operation& nestedOp : parallelOp.getRegion().front()) {
|
||||||
|
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
|
||||||
|
if (!insertSlice)
|
||||||
|
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
|
||||||
|
|
||||||
|
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||||
|
if (!outputArg || outputArg.getOwner() != &oldBlock)
|
||||||
|
return insertSlice.emitOpError("expected compute_batch output block argument destination");
|
||||||
|
|
||||||
|
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||||
|
if (resultIndex >= returnOperandIndices.size())
|
||||||
|
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||||
|
|
||||||
|
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||||
|
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||||
|
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||||
|
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||||
|
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
|
||||||
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
insertSlice.getLoc(),
|
||||||
|
hostTarget.getType(),
|
||||||
|
hostTargetOffset,
|
||||||
|
zeroOffset,
|
||||||
|
hostTarget,
|
||||||
|
mappedSource,
|
||||||
|
getTensorSizeInBytesAttr(rewriter, mappedSource));
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||||
|
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
||||||
|
if (failed(targetCoreIds))
|
||||||
|
return sendBatchOp.emitOpError("expected constant targetCoreIds");
|
||||||
|
for (int32_t& targetCoreId : *targetCoreIds)
|
||||||
|
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||||
pim::PimSendBatchOp::create(rewriter,
|
pim::PimSendBatchOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
mapper.lookup(sendBatchOp.getInput()),
|
mapper.lookup(sendBatchOp.getInput()),
|
||||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||||
sendBatchOp.getTargetCoreIdsAttr());
|
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||||
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
|
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
|
||||||
|
return failure();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||||
|
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
|
||||||
|
if (failed(sourceCoreIds))
|
||||||
|
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||||
|
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||||
|
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||||
@@ -163,14 +337,15 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
outputBuffer.getType(),
|
outputBuffer.getType(),
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||||
receiveBatchOp.getSourceCoreIdsAttr())
|
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
mapper.map(receiveBatchOp.getOutput(), received);
|
mapper.map(receiveBatchOp.getOutput(), received);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||||
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
|
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
|
||||||
|
return failure();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,6 +353,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||||
Operation* cloned = rewriter.clone(op, mapper);
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
auto clonedTensor = cloned->getResult(0);
|
auto clonedTensor = cloned->getResult(0);
|
||||||
|
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
|
||||||
|
mapper.map(toTensorOp.getResult(), clonedTensor);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
@@ -194,9 +373,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (Value operand : op.getOperands()) {
|
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||||
continue;
|
continue;
|
||||||
|
if (isExplicitHostOperand(&op, operandIndex))
|
||||||
|
continue;
|
||||||
|
|
||||||
Operation* definingOp = operand.getDefiningOp();
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
|
|||||||
SpatialToPimPass.cpp
|
SpatialToPimPass.cpp
|
||||||
BatchCoreLoweringPatterns.cpp
|
BatchCoreLoweringPatterns.cpp
|
||||||
ChannelLoweringPatterns.cpp
|
ChannelLoweringPatterns.cpp
|
||||||
Cleanup.cpp
|
|
||||||
Common.cpp
|
Common.cpp
|
||||||
ComputeLikeRegionUtils.cpp
|
ComputeLikeRegionUtils.cpp
|
||||||
CoreLoweringPatterns.cpp
|
CoreLoweringPatterns.cpp
|
||||||
@@ -22,6 +21,8 @@ add_pim_library(OMSpatialToPim
|
|||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRSCFDialect
|
MLIRSCFDialect
|
||||||
|
MLIRSCFUtils
|
||||||
|
MLIRTransformUtils
|
||||||
MLIRTosaDialect
|
MLIRTosaDialect
|
||||||
OMCompilerOptions
|
OMCompilerOptions
|
||||||
OMPimCommon
|
OMPimCommon
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
@@ -12,15 +13,24 @@ namespace {
|
|||||||
|
|
||||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||||
|
SmallVector<int32_t> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (Value value : values) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||||
|
}
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||||
pim::PimSendOp::create(rewriter,
|
pim::PimSendOp::create(
|
||||||
op.getLoc(),
|
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
|
||||||
op.getInput(),
|
|
||||||
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
|
||||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -42,7 +52,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
|||||||
op.getResult().getType(),
|
op.getResult().getType(),
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
op.getSourceCoreId())
|
||||||
.getOutput();
|
.getOutput();
|
||||||
rewriter.replaceOp(op, received);
|
rewriter.replaceOp(op, received);
|
||||||
return success();
|
return success();
|
||||||
@@ -53,11 +63,12 @@ struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTens
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||||
SmallVector<int32_t> targetCoreIds;
|
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
|
||||||
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
if (failed(targetCoreIds))
|
||||||
for (int32_t targetCoreId : op.getTargetCoreIds())
|
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
|
||||||
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
for (int32_t& targetCoreId : *targetCoreIds)
|
||||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
targetCoreId = toPimCoreId(targetCoreId);
|
||||||
|
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -67,16 +78,17 @@ struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelRecei
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
|
||||||
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
if (failed(sourceCoreIds))
|
||||||
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
|
||||||
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||||
|
sourceCoreId = toPimCoreId(sourceCoreId);
|
||||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||||
Value outputBuffer =
|
Value outputBuffer =
|
||||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||||
Value received =
|
Value received =
|
||||||
pim::PimReceiveTensorOp::create(
|
pim::PimReceiveTensorOp::create(
|
||||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
rewriter.replaceOp(op, received);
|
rewriter.replaceOp(op, received);
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -1,42 +0,0 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
|
|
||||||
while (!pendingOps.empty()) {
|
|
||||||
bool erasedAnyOp = false;
|
|
||||||
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
|
|
||||||
Operation* opToRemove = *it;
|
|
||||||
if (!opToRemove->use_empty()) {
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.eraseOp(opToRemove);
|
|
||||||
it = pendingOps.erase(it);
|
|
||||||
erasedAnyOp = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (erasedAnyOp)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (Operation* opToRemove : pendingOps) {
|
|
||||||
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
|
|
||||||
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
|
|
||||||
for (Operation* user : opToRemove->getUsers()) {
|
|
||||||
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
|
|
||||||
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
|
|
||||||
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
|||||||
return returnValue;
|
return returnValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
|
|
||||||
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ namespace onnx_mlir {
|
|||||||
*/
|
*/
|
||||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||||
|
|
||||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
|
||||||
|
|
||||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
#include <cassert>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -29,7 +31,18 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
|||||||
unsigned inputIndex,
|
unsigned inputIndex,
|
||||||
Value replacement) {
|
Value replacement) {
|
||||||
Block& body = owner->getRegion(0).front();
|
Block& body = owner->getRegion(0).front();
|
||||||
BlockArgument bodyArgument = body.getArgument(inputIndex);
|
BlockArgument bodyArgument;
|
||||||
|
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||||
|
auto computeArg = compute.getInputArgument(inputIndex);
|
||||||
|
assert(computeArg && "expected compute input block argument");
|
||||||
|
bodyArgument = *computeArg;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||||
|
assert(batchArg && "expected compute_batch input block argument");
|
||||||
|
bodyArgument = *batchArg;
|
||||||
|
}
|
||||||
|
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||||
|
|
||||||
rewriter.startOpModification(owner);
|
rewriter.startOpModification(owner);
|
||||||
bodyArgument.replaceAllUsesWith(replacement);
|
bodyArgument.replaceAllUsesWith(replacement);
|
||||||
@@ -37,7 +50,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
|||||||
compute.getInputsMutable().erase(inputIndex);
|
compute.getInputsMutable().erase(inputIndex);
|
||||||
else
|
else
|
||||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||||
body.eraseArgument(inputIndex);
|
body.eraseArgument(bodyArgIndex);
|
||||||
rewriter.finalizeOpModification(owner);
|
rewriter.finalizeOpModification(owner);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,12 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) {
|
|||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
static void
|
||||||
|
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
if (mapping.lookupOrNull(operand))
|
if (mapping.lookupOrNull(operand))
|
||||||
continue;
|
continue;
|
||||||
@@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
|||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||||
|
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<tensor::EmptyOp>(definingOp))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||||
@@ -48,6 +55,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
|||||||
|
|
||||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||||
|
SmallVector<int32_t> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (Value value : values) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||||
|
}
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||||
@@ -92,7 +111,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
||||||
|
IRRewriter& rewriter,
|
||||||
|
OperationFolder& constantFolder) {
|
||||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||||
return false;
|
return false;
|
||||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||||
@@ -101,7 +122,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
return false;
|
return false;
|
||||||
|
|
||||||
Block& block = computeOp.getBody().front();
|
Block& block = computeOp.getBody().front();
|
||||||
if (block.getNumArguments() != 0)
|
if (block.getNumArguments() != computeOp.getWeights().size())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
@@ -110,8 +131,14 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
|
|
||||||
rewriter.setInsertionPoint(computeOp);
|
rewriter.setInsertionPoint(computeOp);
|
||||||
IRMapping mapping;
|
IRMapping mapping;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
|
||||||
|
auto weightArg = computeOp.getWeightArgument(weightIndex);
|
||||||
|
if (!weightArg)
|
||||||
|
return false;
|
||||||
|
mapping.map(*weightArg, weight);
|
||||||
|
}
|
||||||
for (Operation& op : block.without_terminator()) {
|
for (Operation& op : block.without_terminator()) {
|
||||||
cloneMappedHelperOperands(&op, mapping, rewriter);
|
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||||
mapping.map(originalResult, newResult);
|
mapping.map(originalResult, newResult);
|
||||||
@@ -125,15 +152,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
IRRewriter& rewriter,
|
||||||
state.operationsToRemove.push_back(op);
|
OperationFolder& constantFolder) {
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
@@ -143,21 +167,44 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
auto& block = computeOp.getRegion().front();
|
auto& block = computeOp.getRegion().front();
|
||||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||||
|
|
||||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||||
if (!receiveOp || blockArg.use_empty())
|
if (!blockArg)
|
||||||
continue;
|
return computeOp.emitOpError("expected compute input block arguments during lowering");
|
||||||
|
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
if (receiveOp && !blockArg->use_empty()) {
|
||||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||||
|
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
|
||||||
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
Value received =
|
||||||
Value received = PimReceiveOp::create(
|
PimReceiveOp::create(
|
||||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||||
.getOutput();
|
.getOutput();
|
||||||
blockArg.replaceAllUsesWith(received);
|
blockArg->replaceAllUsesWith(received);
|
||||||
markOpToRemove(state, receiveOp);
|
markOpToRemove(receiveOp);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||||
|
if (receiveTensorOp && !blockArg->use_empty()) {
|
||||||
|
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||||
|
if (failed(sourceCoreIds))
|
||||||
|
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||||
|
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||||
|
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||||
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||||
|
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||||
|
Value received = PimReceiveTensorOp::create(rewriter,
|
||||||
|
receiveTensorOp.getLoc(),
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||||
|
.getOutput();
|
||||||
|
blockArg->replaceAllUsesWith(received);
|
||||||
|
markOpToRemove(receiveTensorOp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||||
@@ -167,9 +214,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
if (result.use_empty())
|
if (result.use_empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
|
||||||
ReturnPathLoweringResult returnPathResult =
|
ReturnPathLoweringResult returnPathResult =
|
||||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
|
||||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||||
return failure();
|
return failure();
|
||||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||||
@@ -193,15 +239,40 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
|||||||
if (!computeOp.getWeights().empty())
|
if (!computeOp.getWeights().empty())
|
||||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
auto coreOp = PimCoreOp::create(rewriter,
|
auto coreOp = PimCoreOp::create(
|
||||||
loc,
|
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||||
ValueRange(computeWeights),
|
rewriter.setInsertionPointToStart(&block);
|
||||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
|
||||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||||
if (!blockArg.use_empty())
|
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||||
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
if (!blockArg)
|
||||||
block.eraseArguments(0, block.getNumArguments());
|
return computeOp.emitOpError("expected compute input block arguments during input materialization");
|
||||||
|
if (blockArg->use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||||
|
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputType = dyn_cast<ShapedType>(input.getType());
|
||||||
|
if (!inputType)
|
||||||
|
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
|
||||||
|
auto copied =
|
||||||
|
PimMemCopyHostToDevOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputBuffer.getType(),
|
||||||
|
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||||
|
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||||
|
outputBuffer,
|
||||||
|
input,
|
||||||
|
getTensorSizeInBytesAttr(rewriter, input))
|
||||||
|
.getOutput();
|
||||||
|
blockArg->replaceAllUsesWith(copied);
|
||||||
|
}
|
||||||
|
if (!computeOp.getInputs().empty())
|
||||||
|
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
|
||||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||||
Block* tempComputeBlock = new Block();
|
Block* tempComputeBlock = new Block();
|
||||||
computeOp.getBody().push_back(tempComputeBlock);
|
computeOp.getBody().push_back(tempComputeBlock);
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
struct CoreLoweringState {
|
|
||||||
size_t& nextCoreId;
|
|
||||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
|
||||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
|
||||||
};
|
|
||||||
|
|
||||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -76,10 +76,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
auto BBArgIndex = *inputIndex;
|
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
if (!BBArgValue)
|
||||||
|
return failure();
|
||||||
|
|
||||||
if (BBArgValue.use_empty())
|
if (BBArgValue->use_empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
@@ -89,16 +90,17 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
}
|
}
|
||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
auto BBArgIndex = *inputIndex;
|
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
if (!BBArgValue)
|
||||||
|
return failure();
|
||||||
|
|
||||||
if (BBArgValue.use_empty())
|
if (BBArgValue->use_empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
@@ -108,7 +110,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
}
|
}
|
||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
{
|
{
|
||||||
@@ -143,170 +145,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
|
||||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
|
||||||
Location loc = constantOp.getLoc();
|
|
||||||
|
|
||||||
if (hasWeightAlways(constantOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
|
||||||
if (isa<spatial::SpatCompute>(op))
|
|
||||||
return false;
|
|
||||||
if (isa<func::FuncOp>(op->getParentOp()))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
|
||||||
|
|
||||||
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
|
||||||
|
|
||||||
if (constRankedTensorType) {
|
|
||||||
mlir::MemRefType memRefType =
|
|
||||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
|
||||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
|
||||||
loc,
|
|
||||||
constantOp->getParentOfType<ModuleOp>(),
|
|
||||||
"const",
|
|
||||||
memRefType,
|
|
||||||
constantOp.getValueAttr(),
|
|
||||||
rewriter.getUnitAttr());
|
|
||||||
std::string argName = globalOp.getSymName().str();
|
|
||||||
|
|
||||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
|
||||||
|
|
||||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
|
||||||
auto constUsers = constUses.getOwner();
|
|
||||||
|
|
||||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
|
||||||
if (!inputIndex)
|
|
||||||
return failure();
|
|
||||||
auto BBArgIndex = *inputIndex;
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(
|
|
||||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
|
||||||
if (!inputIndex)
|
|
||||||
return failure();
|
|
||||||
auto BBArgIndex = *inputIndex;
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(rewriter,
|
|
||||||
spatComputeBatch.getOperation(),
|
|
||||||
BBArgIndex,
|
|
||||||
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
{
|
|
||||||
|
|
||||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
|
||||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
|
||||||
|
|
||||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
|
||||||
auto constUsers = constUses.getOwner();
|
|
||||||
|
|
||||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
|
||||||
if (!inputIndex)
|
|
||||||
return failure();
|
|
||||||
auto BBArgIndex = *inputIndex;
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(
|
|
||||||
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
|
||||||
if (!inputIndex)
|
|
||||||
return failure();
|
|
||||||
auto BBArgIndex = *inputIndex;
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(
|
|
||||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
|
||||||
}
|
|
||||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
|
||||||
if (!mapSpatComputeToConst.contains(parent)) {
|
|
||||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
|
||||||
}
|
|
||||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
|
||||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
|
||||||
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
|
||||||
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
|
||||||
}
|
|
||||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (constantOp->use_empty())
|
|
||||||
rewriter.eraseOp(constantOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
@@ -383,8 +221,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
||||||
patterns.getContext());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -6,10 +6,12 @@
|
|||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -42,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
|
|||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
|
||||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
|
||||||
state.operationsToRemove.push_back(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||||
std::string name = baseName.str();
|
std::string name = baseName.str();
|
||||||
unsigned suffix = 0;
|
unsigned suffix = 0;
|
||||||
@@ -318,7 +315,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
static void
|
||||||
|
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
if (mapping.lookupOrNull(operand))
|
if (mapping.lookupOrNull(operand))
|
||||||
continue;
|
continue;
|
||||||
@@ -327,7 +325,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
|||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||||
|
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<tensor::EmptyOp>(definingOp))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||||
@@ -337,15 +340,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void cloneHelperChain(Value sourceValue,
|
||||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
ArrayRef<Operation*> helperChain,
|
||||||
|
IRRewriter& rewriter,
|
||||||
|
OperationFolder& constantFolder,
|
||||||
|
Value& clonedValue) {
|
||||||
IRMapping mapping;
|
IRMapping mapping;
|
||||||
mapping.map(sourceValue, sourceValue);
|
mapping.map(sourceValue, sourceValue);
|
||||||
clonedValue = sourceValue;
|
clonedValue = sourceValue;
|
||||||
|
|
||||||
rewriter.setInsertionPointAfterValue(sourceValue);
|
rewriter.setInsertionPointAfterValue(sourceValue);
|
||||||
for (Operation* op : helperChain) {
|
for (Operation* op : helperChain) {
|
||||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
|
||||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
mapping.map(originalResult, newResult);
|
mapping.map(originalResult, newResult);
|
||||||
@@ -360,23 +366,26 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
|||||||
Value sourceValue,
|
Value sourceValue,
|
||||||
int32_t hostTargetOffset,
|
int32_t hostTargetOffset,
|
||||||
int32_t deviceSourceOffset,
|
int32_t deviceSourceOffset,
|
||||||
int32_t sizeInBytes) {
|
int32_t sizeInBytes,
|
||||||
|
OperationFolder& constantFolder) {
|
||||||
|
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
||||||
|
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
||||||
|
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
|
||||||
|
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
|
||||||
return PimMemCopyDevToHostOp::create(rewriter,
|
return PimMemCopyDevToHostOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
|
hostTargetOffsetValue,
|
||||||
|
deviceSourceOffsetValue,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
sourceValue,
|
sourceValue,
|
||||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
|
||||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||||
IRRewriter& rewriter,
|
|
||||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
|
||||||
outputTensors.reserve(returnOp->getNumOperands());
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||||
Value currentReturnValue = returnValue;
|
Value currentReturnValue = returnValue;
|
||||||
@@ -411,70 +420,85 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
|
||||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = producerOp->getLoc();
|
||||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
OperationFolder constantFolder(producerOp->getContext());
|
||||||
|
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||||
|
|
||||||
if (auto returnUse = analyzeReturnUse(result)) {
|
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
||||||
Value storedValue = yieldValue;
|
Value currentStoredValue = storedValue;
|
||||||
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||||
for (Operation* op : returnUse->helperChain)
|
for (Operation* op : returnUse->helperChain)
|
||||||
markOpToRemove(state, op);
|
markOpToRemove(op);
|
||||||
|
|
||||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
|
||||||
if (auto storedOp = storedValue.getDefiningOp())
|
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||||
rewriter.setInsertionPointAfter(storedOp);
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||||
emitHostCopy(
|
emitHostCopy(rewriter,
|
||||||
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
loc,
|
||||||
|
outputTensor,
|
||||||
|
currentStoredValue,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
static_cast<int32_t>(storedType.getNumElements() * elementSize),
|
||||||
|
constantFolder);
|
||||||
return ReturnPathLoweringResult::Handled;
|
return ReturnPathLoweringResult::Handled;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultUses = result.getUses();
|
auto resultUses = producedValue.getUses();
|
||||||
if (rangeLength(resultUses) == 1) {
|
if (rangeLength(resultUses) == 1) {
|
||||||
OpOperand& resultUse = *resultUses.begin();
|
OpOperand& resultUse = *resultUses.begin();
|
||||||
Operation* resultUser = resultUse.getOwner();
|
Operation* resultUser = resultUse.getOwner();
|
||||||
|
|
||||||
if (isa<func::ReturnOp>(resultUser)) {
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||||
emitHostCopy(
|
emitHostCopy(rewriter,
|
||||||
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
loc,
|
||||||
|
outputTensor,
|
||||||
|
storedValue,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||||
|
constantFolder);
|
||||||
return ReturnPathLoweringResult::Handled;
|
return ReturnPathLoweringResult::Handled;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
|
|
||||||
if (concatReturnUse->helperChain.empty()) {
|
if (concatReturnUse->helperChain.empty()) {
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||||
emitHostCopy(rewriter,
|
emitHostCopy(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor,
|
outputTensor,
|
||||||
yieldValue,
|
storedValue,
|
||||||
static_cast<int32_t>(flatOffset * elementSize),
|
static_cast<int32_t>(flatOffset * elementSize),
|
||||||
0,
|
0,
|
||||||
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||||
|
constantFolder);
|
||||||
return ReturnPathLoweringResult::Handled;
|
return ReturnPathLoweringResult::Handled;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||||
if (!storedType) {
|
if (!storedType) {
|
||||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
producerOp->emitOpError(
|
||||||
|
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||||
@@ -484,7 +508,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
|||||||
SmallVector<int64_t> destinationIndices;
|
SmallVector<int64_t> destinationIndices;
|
||||||
if (failed(mapIndicesThroughHelperChain(
|
if (failed(mapIndicesThroughHelperChain(
|
||||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -503,7 +527,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
|||||||
auto scalarTensorType =
|
auto scalarTensorType =
|
||||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
|
||||||
rewriter.setInsertionPointAfter(elementSlice);
|
rewriter.setInsertionPointAfter(elementSlice);
|
||||||
|
|
||||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||||
@@ -513,7 +537,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
|||||||
elementSlice.getResult(),
|
elementSlice.getResult(),
|
||||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||||
0,
|
0,
|
||||||
static_cast<int32_t>(elementSize));
|
static_cast<int32_t>(elementSize),
|
||||||
|
constantFolder);
|
||||||
}
|
}
|
||||||
return ReturnPathLoweringResult::Handled;
|
return ReturnPathLoweringResult::Handled;
|
||||||
}
|
}
|
||||||
@@ -521,7 +546,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
|||||||
return ReturnPathLoweringResult::NotReturnPath;
|
return ReturnPathLoweringResult::NotReturnPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||||
|
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||||
|
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||||
if (!op)
|
if (!op)
|
||||||
return;
|
return;
|
||||||
@@ -538,13 +568,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
|||||||
|
|
||||||
if (isReturnHelperChainOp(op)) {
|
if (isReturnHelperChainOp(op)) {
|
||||||
Value source = op->getOperand(0);
|
Value source = op->getOperand(0);
|
||||||
markOpToRemove(state, op);
|
markOpToRemove(op);
|
||||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
markOpToRemove(state, computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (!computeOp.getInputs().empty())
|
if (!computeOp.getInputs().empty())
|
||||||
for (Value input : computeOp.getInputs())
|
for (Value input : computeOp.getInputs())
|
||||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||||
@@ -552,24 +582,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getOperands())
|
for (Value operand : concatOp.getOperands())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getInputs())
|
for (Value operand : concatOp.getInputs())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||||
markOpToRemove(state, concatOp);
|
markOpToRemove(concatOp);
|
||||||
for (Value operand : concatOp.getInputs())
|
for (Value operand : concatOp.getInputs())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||||
|
markOpToRemove(receiveOp);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||||
|
markOpToRemove(receiveTensorOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
@@ -578,7 +617,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
|||||||
size_t orderWithinReturn = it.index();
|
size_t orderWithinReturn = it.index();
|
||||||
Operation* returnOperand = it.value().getDefiningOp();
|
Operation* returnOperand = it.value().getDefiningOp();
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
|
||||||
|
|
||||||
struct ReturnPathState {
|
|
||||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
|
||||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum class ReturnPathLoweringResult {
|
|
||||||
Handled,
|
|
||||||
NotReturnPath,
|
|
||||||
Failure
|
|
||||||
};
|
|
||||||
|
|
||||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
|
||||||
mlir::IRRewriter& rewriter,
|
|
||||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
|
||||||
|
|
||||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
|
||||||
mlir::OpResult result,
|
|
||||||
mlir::Value yieldValue,
|
|
||||||
ReturnPathState& state,
|
|
||||||
mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVMM : Pat<
|
def spatToPimVMM : Pat<
|
||||||
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
|
(SpatVMMOp:$srcOpRes $weight, $vector),
|
||||||
(PimVMMOp $weightIndex, $vector,
|
(PimVMMOp $weight, $vector,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinDialect.h"
|
#include "mlir/IR/BuiltinDialect.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
@@ -12,6 +13,7 @@
|
|||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
@@ -21,54 +23,28 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
#include "Pass/PIMPasses.h"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
#include "SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
using namespace pim;
|
using namespace pim;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
namespace raptor {
|
||||||
namespace {
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||||
|
|
||||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
} // namespace raptor
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
|
||||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
|
||||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
|
||||||
|
|
||||||
SpatialToPimPass() = default;
|
|
||||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
|
||||||
|
|
||||||
void runOnOperation() final;
|
|
||||||
|
|
||||||
private:
|
|
||||||
SmallVector<OutputTensorFactory> outputTensors;
|
|
||||||
size_t coreId = 0;
|
|
||||||
SmallVector<Operation*> operationsToRemove;
|
|
||||||
|
|
||||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
|
||||||
|
|
||||||
void markOpToRemove(Operation* op);
|
|
||||||
|
|
||||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||||
@@ -104,23 +80,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
|
|||||||
IntegerAttr {});
|
IntegerAttr {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
RankedTensorType tensorType,
|
||||||
|
OperationFolder& constantFolder) {
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
|
||||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||||
|
|
||||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||||
return PimMemCopyHostToDevBatchOp::create(
|
return PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||||
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
loc,
|
||||||
|
tensorType,
|
||||||
|
outputBuffer,
|
||||||
|
zeroValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
sizeAttr)
|
||||||
.getOutput();
|
.getOutput();
|
||||||
|
|
||||||
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
return PimMemCopyHostToDevOp::create(
|
||||||
|
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
static Value
|
||||||
|
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||||
@@ -131,14 +118,16 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
|||||||
|
|
||||||
auto paddedType = RankedTensorType::get(
|
auto paddedType = RankedTensorType::get(
|
||||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
|
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnOperation() {
|
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 0;
|
coreId = 0;
|
||||||
|
outputTensors.clear();
|
||||||
|
operationsToRemove.clear();
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
|
||||||
@@ -151,6 +140,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
func::FuncOp funcOp = *entryFunc;
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
|
OperationFolder constantFolder(&getContext());
|
||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<PimDialect,
|
target.addLegalDialect<PimDialect,
|
||||||
@@ -181,19 +171,18 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
RewritePatternSet globalTensorPatterns(ctx);
|
RewritePatternSet globalTensorPatterns(ctx);
|
||||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
|
||||||
|
|
||||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
|
addReturnOutputBuffers(returnOp, rewriter);
|
||||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
|
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
@@ -202,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
|
|
||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||||
markOpToRemove(computeBatchOp);
|
markOpToRemove(computeBatchOp);
|
||||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
@@ -251,15 +240,8 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
eraseOpsToRemove();
|
||||||
|
|
||||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
|
||||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
|
||||||
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||||
@@ -301,7 +283,8 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
dumpModule(moduleOp, "pim0");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
@@ -309,7 +292,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||||
|
|
||||||
rewriter.setInsertionPoint(vmmOp);
|
rewriter.setInsertionPoint(vmmOp);
|
||||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||||
auto paddedOutputType = RankedTensorType::get(
|
auto paddedOutputType = RankedTensorType::get(
|
||||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||||
@@ -334,13 +317,17 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
|
|
||||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||||
Type elementType = tensorType.getElementType();
|
Type elementType = tensorType.getElementType();
|
||||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
if (!hasByteSizedElementType(elementType))
|
||||||
|
return;
|
||||||
|
size_t elementByteSize = getElementTypeSizeInBytes(elementType);
|
||||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||||
|
|
||||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||||
@@ -349,10 +336,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
rewriter,
|
rewriter,
|
||||||
loc,
|
loc,
|
||||||
tensorType,
|
tensorType,
|
||||||
|
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
|
||||||
|
getOrCreateHostIndexConstant(
|
||||||
|
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
|
||||||
deviceTensor,
|
deviceTensor,
|
||||||
inputTensor,
|
inputTensor,
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||||
|
|
||||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||||
@@ -374,11 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||||
if (!llvm::is_contained(operationsToRemove, op))
|
if (!llvm::is_contained(operationsToRemove, op))
|
||||||
operationsToRemove.push_back(op);
|
operationsToRemove.push_back(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
||||||
|
for (Operation* op : operationsToRemove) {
|
||||||
|
op->dropAllUses();
|
||||||
|
op->erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace raptor {
|
||||||
|
|
||||||
|
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||||
|
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||||
|
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||||
|
|
||||||
|
SpatialToPimPass() = default;
|
||||||
|
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||||
|
|
||||||
|
void runOnOperation() final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||||
|
|
||||||
|
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||||
|
size_t coreId = 0;
|
||||||
|
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||||
|
|
||||||
|
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
|
mlir::LogicalResult
|
||||||
|
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
||||||
|
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
enum class ReturnPathLoweringResult {
|
||||||
|
Handled,
|
||||||
|
NotReturnPath,
|
||||||
|
Failure
|
||||||
|
};
|
||||||
|
|
||||||
|
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||||
|
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||||
|
mlir::OpResult result,
|
||||||
|
mlir::Value yieldValue,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||||
|
mlir::Value producedValue,
|
||||||
|
mlir::Value storedValue,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
|
void markOpToRemove(mlir::Operation* op);
|
||||||
|
void eraseOpsToRemove();
|
||||||
|
|
||||||
|
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace raptor
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+34
-16
@@ -2,6 +2,7 @@
|
|||||||
#define PIM_DIALECT_H
|
#define PIM_DIALECT_H
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
include "mlir/IR/AttrTypeBase.td"
|
include "mlir/IR/AttrTypeBase.td"
|
||||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
@@ -24,7 +25,8 @@ def PimTensor :
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
def PimCoreOp : PimOp<"core", [SingleBlock,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Execute a block on a PIM core";
|
let summary = "Execute a block on a PIM core";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
|||||||
I32Attr:$coreId
|
I32Attr:$coreId
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let extraClassDeclaration = [{
|
||||||
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
|
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Execute equivalent batched core bodies";
|
let summary = "Execute equivalent batched core bodies";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
|
|||||||
Variadic<PimTensor>:$inputs
|
Variadic<PimTensor>:$inputs
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
::mlir::BlockArgument getLaneArgument();
|
||||||
|
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||||
|
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> {
|
|||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
PimTensor:$input,
|
PimTensor:$input,
|
||||||
I32Attr:$size,
|
I32Attr:$size,
|
||||||
I32Attr:$targetCoreId
|
Index:$targetCoreId
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
`(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
|||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
PimTensor:$outputBuffer,
|
PimTensor:$outputBuffer,
|
||||||
I32Attr:$size,
|
I32Attr:$size,
|
||||||
I32Attr:$sourceCoreId
|
Index:$sourceCoreId
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
`(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
|||||||
let summary = "Copy a memory region from host memory into device memory";
|
let summary = "Copy a memory region from host memory into device memory";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
Index:$deviceTargetOffset,
|
||||||
|
Index:$hostSourceOffset,
|
||||||
PimTensor:$deviceTarget,
|
PimTensor:$deviceTarget,
|
||||||
PimTensor:$hostSource,
|
PimTensor:$hostSource,
|
||||||
I32Attr:$deviceTargetOffset,
|
|
||||||
I32Attr:$hostSourceOffset,
|
|
||||||
I32Attr:$size
|
I32Attr:$size
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
`[` $deviceTargetOffset `,` $hostSourceOffset `]`
|
||||||
|
`(` $deviceTarget `,` $hostSource `)` attr-dict
|
||||||
|
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
|||||||
let summary = "Copy a memory region from device memory into host memory";
|
let summary = "Copy a memory region from device memory into host memory";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
Index:$hostTargetOffset,
|
||||||
|
Index:$deviceSourceOffset,
|
||||||
PimTensor:$hostTarget,
|
PimTensor:$hostTarget,
|
||||||
PimTensor:$deviceSource,
|
PimTensor:$deviceSource,
|
||||||
I32Attr:$hostTargetOffset,
|
|
||||||
I32Attr:$deviceSourceOffset,
|
|
||||||
I32Attr:$size
|
I32Attr:$size
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
`[` $hostTargetOffset `,` $deviceSourceOffset `]`
|
||||||
|
`(` $hostTarget `,` $deviceSource `)` attr-dict
|
||||||
|
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
|||||||
let summary = "Vector-matrix multiplication: c = a * b";
|
let summary = "Vector-matrix multiplication: c = a * b";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Attr:$weightIndex,
|
PimTensor:$weight,
|
||||||
PimTensor:$input,
|
PimTensor:$input,
|
||||||
PimTensor:$outputBuffer
|
PimTensor:$outputBuffer
|
||||||
);
|
);
|
||||||
@@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
|||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
`[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
|
||||||
|
type($outputBuffer) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,41 @@
|
|||||||
|
#include <string>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
|
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||||
|
|
||||||
|
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
if (region.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||||
|
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||||
|
|
||||||
|
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||||
|
|
||||||
|
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
|
||||||
|
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
if (region.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
setNameFn(getLaneArgument(), "lane");
|
||||||
|
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||||
|
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||||
|
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||||
|
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
void PimDialect::initialize() {
|
void PimDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
|||||||
@@ -20,6 +20,79 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
|
|||||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||||
|
return parser.getBuilder().getI32IntegerAttr(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||||
|
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||||
|
printer << "(";
|
||||||
|
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printOperand(argument);
|
||||||
|
}
|
||||||
|
printer << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalRParen()))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
if (parser.parseArgument(argument))
|
||||||
|
return failure();
|
||||||
|
arguments.push_back(argument);
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseArgument(argument))
|
||||||
|
return failure();
|
||||||
|
arguments.push_back(argument);
|
||||||
|
}
|
||||||
|
return parser.parseRParen();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||||
|
printCompressedValueList(printer, arguments, delimiter);
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||||
|
ListDelimiter delimiter,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||||
|
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||||
|
switch (currentDelimiter) {
|
||||||
|
case ListDelimiter::Paren: return parser.parseRParen();
|
||||||
|
case ListDelimiter::Square: return parser.parseRSquare();
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported delimiter");
|
||||||
|
};
|
||||||
|
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
|
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
|
||||||
printer << " " << keyword << " ";
|
printer << " " << keyword << " ";
|
||||||
printCompressedIntegerList(printer, coreIds);
|
printCompressedIntegerList(printer, coreIds);
|
||||||
@@ -33,15 +106,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
void PimCoreOp::print(OpAsmPrinter& printer) {
|
||||||
printer << " lanes " << getLaneCount() << " ";
|
SmallVector<Value> weightArgs;
|
||||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
weightArgs.reserve(getWeights().size());
|
||||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
|
weightArgs.push_back(getWeightArgument(index));
|
||||||
else
|
|
||||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
|
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
|
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " coreId " << getCoreId();
|
||||||
|
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||||
|
printer << " -> () ";
|
||||||
|
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
|
SmallVector<Type> weightTypes;
|
||||||
|
int32_t coreId = 0;
|
||||||
|
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||||
|
if (hasCoreId && parser.parseInteger(coreId))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (weights.size() != weightTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
|
if (hasCoreId && result.attributes.get("coreId"))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"coreId cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (hasCoreId)
|
||||||
|
result.addAttribute("coreId", getI32Attr(parser, coreId));
|
||||||
|
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
applyArgumentTypes(weightTypes, weightArgs);
|
||||||
|
return parser.parseRegion(*body, weightArgs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||||
|
printer << " ";
|
||||||
|
printer.printOperand(getLaneArgument());
|
||||||
|
printer << " = 0 to " << getLaneCount() << " ";
|
||||||
|
|
||||||
|
SmallVector<Value> weightArgs;
|
||||||
|
weightArgs.reserve(getWeights().size());
|
||||||
|
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||||
|
weightArgs.push_back(getWeightArgument(index));
|
||||||
|
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(getInputs().size());
|
||||||
|
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||||
|
inputArgs.push_back(getInputArgument(index));
|
||||||
|
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||||
|
|
||||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
|
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
|
||||||
@@ -49,51 +183,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
|||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(
|
||||||
(*this)->getAttrs(),
|
(*this)->getAttrs(),
|
||||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||||
printer << " ";
|
|
||||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
|
|
||||||
else
|
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
|
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
|
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||||
printer << " -> () ";
|
printer << " -> () ";
|
||||||
|
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
int64_t lowerBound = 0;
|
||||||
int32_t laneCount = 0;
|
int32_t laneCount = 0;
|
||||||
|
OpAsmParser::Argument laneArg;
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
SmallVector<Type> weightTypes;
|
SmallVector<Type> weightTypes;
|
||||||
SmallVector<Type> inputTypes;
|
SmallVector<Type> inputTypes;
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
|
|
||||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||||
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||||
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
|
return failure();
|
||||||
|
if (lowerBound != 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
|
||||||
|
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|
||||||
|
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes))
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
return failure();
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
Region* body = result.addRegion();
|
|| parseCompressedRepeatedList(
|
||||||
if (parser.parseRegion(*body))
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
return failure();
|
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||||
|
|
||||||
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|
|
||||||
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|
|
||||||
|| parser.parseLParen() || parser.parseRParen())
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (weights.size() != weightTypes.size())
|
if (weights.size() != weightTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
if (inputs.size() != inputTypes.size())
|
if (inputs.size() != inputTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (inputArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
|
||||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
"coreIds cannot be specified both positionally and in attr-dict");
|
"coreIds cannot be specified both positionally and in attr-dict");
|
||||||
@@ -110,7 +250,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
return success();
|
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
laneArg.type = builder.getIndexType();
|
||||||
|
regionArgs.push_back(laneArg);
|
||||||
|
applyArgumentTypes(weightTypes, weightArgs);
|
||||||
|
llvm::append_range(regionArgs, weightArgs);
|
||||||
|
applyArgumentTypes(inputTypes, inputArgs);
|
||||||
|
llvm::append_range(regionArgs, inputArgs);
|
||||||
|
return parser.parseRegion(*body, regionArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimYieldOp::print(OpAsmPrinter& printer) {
|
void PimYieldOp::print(OpAsmPrinter& printer) {
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
@@ -14,6 +19,63 @@ namespace pim {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
|
if (isa<PimMemCopyHostToDevOp>(op))
|
||||||
|
return operandIndex == 3;
|
||||||
|
if (isa<PimMemCopyHostToDevBatchOp>(op))
|
||||||
|
return operandIndex == 1;
|
||||||
|
if (isa<PimMemCopyDevToHostOp>(op))
|
||||||
|
return operandIndex == 2;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Region* getParentRegion(Value value) {
|
||||||
|
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||||
|
return blockArgument.getParentRegion();
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
return definingOp ? definingOp->getParentRegion() : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||||
|
Region* parentRegion = getParentRegion(value);
|
||||||
|
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isConstantExternalValue(Value value) {
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return false;
|
||||||
|
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto moduleOp = definingOp->getParentOfType<ModuleOp>();
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
return globalOp && globalOp.getConstant();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
region.walk([&](Operation* op) {
|
||||||
|
for (OpOperand& operand : op->getOpOperands()) {
|
||||||
|
Value value = operand.get();
|
||||||
|
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|
||||||
|
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||||
|
<< kind << " body may only directly reference external constants";
|
||||||
|
diagnostic.attachNote(op->getLoc())
|
||||||
|
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
||||||
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
||||||
}
|
}
|
||||||
@@ -78,24 +140,43 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
||||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||||
if (weightIndex >= coreOp.getWeights().size())
|
if (!shapedType)
|
||||||
return failure();
|
|
||||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
|
||||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
|
||||||
return failure();
|
|
||||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
|
return shapedType.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult PimCoreOp::verify() {
|
||||||
|
Block& block = getBody().front();
|
||||||
|
if (block.getNumArguments() != getWeights().size())
|
||||||
|
return emitError("core body must have one block argument per weight");
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||||
|
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||||
|
return emitError("core weight block argument types must match weight operand types exactly");
|
||||||
|
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult PimCoreBatchOp::verify() {
|
||||||
|
if (getLaneCount() <= 0)
|
||||||
|
return emitError("laneCount must be positive");
|
||||||
|
Block& block = getBody().front();
|
||||||
|
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
|
||||||
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
|
return emitError("core_batch body must have lane, weight, and input block arguments");
|
||||||
|
if (!getLaneArgument().getType().isIndex())
|
||||||
|
return emitError("core_batch first block argument must have index type");
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||||
|
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||||
|
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||||
|
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||||
|
return emitError("core_batch input block argument types must match input operand types exactly");
|
||||||
|
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult PimSendTensorOp::verify() {
|
LogicalResult PimSendTensorOp::verify() {
|
||||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||||
}
|
}
|
||||||
@@ -126,9 +207,9 @@ LogicalResult PimVMMOp::verify() {
|
|||||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
return emitError("weight must be a shaped value");
|
||||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||||
|
|
||||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
|||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
||||||
|
|
||||||
return PimMemCopyOp::create(rewriter,
|
return PimMemCopyOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||||
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
|
||||||
return builder.getI32IntegerAttr(sizeInBytes);
|
return builder.getI32IntegerAttr(sizeInBytes);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
|
|||||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||||
memCopyHostToDevOp,
|
memCopyHostToDevOp,
|
||||||
deviceTargetMemRef.getType(),
|
deviceTargetMemRef.getType(),
|
||||||
|
memCopyHostToDevOp.getDeviceTargetOffset(),
|
||||||
|
memCopyHostToDevOp.getHostSourceOffset(),
|
||||||
deviceTargetMemRef,
|
deviceTargetMemRef,
|
||||||
hostSourceMemRef,
|
hostSourceMemRef,
|
||||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getSizeAttr());
|
memCopyHostToDevOp.getSizeAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
|
|||||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||||
memCopyDevToHostOp,
|
memCopyDevToHostOp,
|
||||||
hostTargetMemRef.getType(),
|
hostTargetMemRef.getType(),
|
||||||
|
memCopyDevToHostOp.getHostTargetOffset(),
|
||||||
|
memCopyDevToHostOp.getDeviceSourceOffset(),
|
||||||
hostTargetMemRef,
|
hostTargetMemRef,
|
||||||
deviceSourceMemRef,
|
deviceSourceMemRef,
|
||||||
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
|
||||||
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
|
||||||
memCopyDevToHostOp.getSizeAttr());
|
memCopyDevToHostOp.getSizeAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
||||||
op,
|
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||||
outputBufferOpt->getType(),
|
|
||||||
*outputBufferOpt,
|
|
||||||
receiveOp.getSizeAttr(),
|
|
||||||
receiveOp.getSourceCoreIdAttr());
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
|||||||
op,
|
op,
|
||||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||||
sendOp.getSizeAttr(),
|
sendOp.getSizeAttr(),
|
||||||
sendOp.getTargetCoreIdAttr());
|
sendOp.getTargetCoreId());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||||
|
auto coreOp = cast<PimCoreOp>(op);
|
||||||
|
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||||
|
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
unsigned weightIndex = bbArg.getArgNumber();
|
||||||
|
return {
|
||||||
|
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||||
|
|
||||||
|
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||||
|
Value value,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
const BufferizationState& state,
|
||||||
|
SmallVector<Value>& invocationStack) const {
|
||||||
|
auto coreOp = cast<PimCoreOp>(op);
|
||||||
|
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||||
|
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
|
||||||
|
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
|
||||||
|
return memRefType;
|
||||||
|
|
||||||
|
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
RewriterBase& rewriter,
|
RewriterBase& rewriter,
|
||||||
const BufferizationOptions& options,
|
const BufferizationOptions& options,
|
||||||
@@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
|||||||
auto coreOp = cast<PimCoreOp>(op);
|
auto coreOp = cast<PimCoreOp>(op);
|
||||||
|
|
||||||
bool alreadyBufferized =
|
bool alreadyBufferized =
|
||||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
|
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||||
|
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||||
|
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||||
|
});
|
||||||
if (alreadyBufferized)
|
if (alreadyBufferized)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
@@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
unsigned argNumber = bbArg.getArgNumber();
|
||||||
|
if (argNumber == 0)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||||
|
unsigned operandIndex = argNumber - 1;
|
||||||
|
if (argNumber > weightCount + 1)
|
||||||
|
operandIndex = weightCount + (argNumber - 1 - weightCount);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
|
{&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
|
unsigned argNumber = bbArg.getArgNumber();
|
||||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
|
if (argNumber == 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value tiedOperand;
|
||||||
|
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||||
|
if (argNumber <= weightCount)
|
||||||
|
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
|
||||||
|
else
|
||||||
|
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
|
||||||
|
|
||||||
|
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
|
||||||
return memRefType;
|
return memRefType;
|
||||||
|
|
||||||
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
|
return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
bool alreadyBufferized =
|
bool alreadyBufferized =
|
||||||
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||||
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
||||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
|
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||||
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
|
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||||
|
});
|
||||||
if (alreadyBufferized)
|
if (alreadyBufferized)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
@@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto vmmOp = cast<PimVMMOp>(op);
|
auto vmmOp = cast<PimVMMOp>(op);
|
||||||
|
|
||||||
|
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
|
||||||
|
if (failed(weightOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool isCandidateAllocType(MemRefType type) {
|
static bool isCandidateAllocType(MemRefType type) {
|
||||||
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
|
return type && type.hasStaticShape() && type.getLayout().isIdentity()
|
||||||
|
&& hasByteSizedElementType(type.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint64_t getTypeSizeBytes(MemRefType type) {
|
static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||||
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<uint64_t>
|
static FailureOr<uint64_t>
|
||||||
@@ -50,11 +52,10 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
|||||||
pendingValues.push_back(result);
|
pendingValues.push_back(result);
|
||||||
|
|
||||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
|
||||||
if (initArg == value)
|
if (initArg == value)
|
||||||
pendingValues.push_back(forOp.getResult(index));
|
pendingValues.push_back(forOp.getResult(index));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||||
for (OpResult result : user->getResults()) {
|
for (OpResult result : user->getResults()) {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||||
|
|
||||||
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
static FailureOr<int64_t> getConstantI64(Value value) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
return constantValue.getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
static FailureOr<int32_t> getConstantI32(Value value) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
|
||||||
|
return getConstantI64(sendOp.getChannelId());
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
|
||||||
|
return getConstantI64(receiveOp.getChannelId());
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
|
||||||
|
|
||||||
|
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
|
||||||
|
return getConstantI32(receiveOp.getSourceCoreId());
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
|
||||||
|
|
||||||
|
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
|
||||||
|
return getConstantI32(receiveOp.getTargetCoreId());
|
||||||
|
}
|
||||||
|
|
||||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
||||||
if (!endpoints.send || !endpoints.receive)
|
if (!endpoints.send || !endpoints.receive)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
|
||||||
|
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
|
||||||
|
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
|
||||||
|
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (*sendSourceCoreId != *receiveSourceCoreId) {
|
||||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
|
||||||
|
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
|
||||||
|
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
|
||||||
|
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
|
||||||
|
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (*sendTargetCoreId != *receiveTargetCoreId) {
|
||||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
|
|||||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
||||||
|
|
||||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
||||||
ChannelId channelId = getChannelId(sendOp);
|
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
if (failed(channelId))
|
||||||
endpoints[channelId].send = sendOp;
|
return;
|
||||||
|
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||||
|
endpoints[*channelId].send = sendOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
||||||
ChannelId channelId = getChannelId(receiveOp);
|
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
if (failed(channelId))
|
||||||
endpoints[channelId].receive = receiveOp;
|
return;
|
||||||
|
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||||
|
endpoints[*channelId].receive = receiveOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||||
ChannelId channelId = getChannelId(sendOp);
|
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||||
auto it = endpoints.find(channelId);
|
if (failed(channelId))
|
||||||
|
return;
|
||||||
|
auto it = endpoints.find(*channelId);
|
||||||
if (it == endpoints.end())
|
if (it == endpoints.end())
|
||||||
return;
|
return;
|
||||||
it->second.send = {};
|
it->second.send = {};
|
||||||
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
||||||
ChannelId channelId = getChannelId(receiveOp);
|
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||||
auto it = endpoints.find(channelId);
|
if (failed(channelId))
|
||||||
|
return;
|
||||||
|
auto it = endpoints.find(*channelId);
|
||||||
if (it == endpoints.end())
|
if (it == endpoints.end())
|
||||||
return;
|
return;
|
||||||
it->second.receive = {};
|
it->second.receive = {};
|
||||||
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
||||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||||
|
if (failed(channelId))
|
||||||
|
return failure();
|
||||||
|
auto endpointsOr = lookup(*channelId);
|
||||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
if (failed(endpointsOr) || !endpointsOr->receive)
|
||||||
return failure();
|
return failure();
|
||||||
return endpointsOr->receive;
|
return endpointsOr->receive;
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
||||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||||
|
if (failed(channelId))
|
||||||
|
return failure();
|
||||||
|
auto endpointsOr = lookup(*channelId);
|
||||||
if (failed(endpointsOr) || !endpointsOr->send)
|
if (failed(endpointsOr) || !endpointsOr->send)
|
||||||
return failure();
|
return failure();
|
||||||
return endpointsOr->send;
|
return endpointsOr->send;
|
||||||
|
|||||||
@@ -2,8 +2,12 @@
|
|||||||
#define SPATIAL_DIALECT_H
|
#define SPATIAL_DIALECT_H
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir/IR/OpAsmInterface.td"
|
||||||
include "mlir/IR/BuiltinTypes.td"
|
include "mlir/IR/BuiltinTypes.td"
|
||||||
include "mlir/IR/AttrTypeBase.td"
|
include "mlir/IR/AttrTypeBase.td"
|
||||||
|
include "mlir/IR/RegionKindInterface.td"
|
||||||
|
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
||||||
def SpatialDialect : Dialect {
|
def SpatialDialect : Dialect {
|
||||||
let name = "spat";
|
let name = "spat";
|
||||||
@@ -22,7 +26,9 @@ def SpatTensor :
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
def SpatCompute : SpatOp<"compute",
|
||||||
|
[SingleBlock, AttrSizedOperandSegments,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Compute region with attached constant weights";
|
let summary = "Compute region with attached constant weights";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
@@ -36,14 +42,26 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
||||||
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
|
}];
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
def SpatComputeBatch : SpatOp<"compute_batch",
|
||||||
[SingleBlock, AttrSizedOperandSegments]> {
|
[SingleBlock, AttrSizedOperandSegments,
|
||||||
let summary = "Compressed batch of independent equivalent compute lanes";
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
|
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Attr:$laneCount,
|
I32Attr:$laneCount,
|
||||||
@@ -57,10 +75,47 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||||
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
||||||
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
|
}];
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||||
|
Pure,
|
||||||
|
Terminator,
|
||||||
|
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||||
|
HasParent<"SpatComputeBatch">,
|
||||||
|
] # GraphRegionNoTerminator.traits> {
|
||||||
|
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins)>,
|
||||||
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
|
||||||
|
::mlir::OpResult getParentResult(int64_t idx);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||||
let summary = "Yield results from a compute region";
|
let summary = "Yield results from a compute region";
|
||||||
|
|
||||||
@@ -110,14 +165,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
|||||||
let summary = "Send a tensor through a logical channel";
|
let summary = "Send a tensor through a logical channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I64Attr:$channelId,
|
Index:$channelId,
|
||||||
I32Attr:$sourceCoreId,
|
Index:$sourceCoreId,
|
||||||
I32Attr:$targetCoreId,
|
Index:$targetCoreId,
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$input attr-dict `:` type($input)
|
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,9 +180,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
|||||||
let summary = "Receive a tensor from a logical channel";
|
let summary = "Receive a tensor from a logical channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I64Attr:$channelId,
|
Index:$channelId,
|
||||||
I32Attr:$sourceCoreId,
|
Index:$sourceCoreId,
|
||||||
I32Attr:$targetCoreId
|
Index:$targetCoreId
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -135,31 +190,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
attr-dict `:` type($output)
|
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
|
||||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
Variadic<Index>:$channelIds,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
Variadic<Index>:$sourceCoreIds,
|
||||||
DenseI32ArrayAttr:$targetCoreIds,
|
Variadic<Index>:$targetCoreIds,
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
|
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
|
||||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
Variadic<Index>:$channelIds,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
Variadic<Index>:$sourceCoreIds,
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
Variadic<Index>:$targetCoreIds
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -167,44 +224,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
|
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
|
||||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
let summary = "Send per-lane tensors through logical channels in a batch body";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
Variadic<Index>:$channelIds,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
Variadic<Index>:$sourceCoreIds,
|
||||||
DenseI32ArrayAttr:$targetCoreIds,
|
Variadic<Index>:$targetCoreIds,
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
|
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
|
||||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
Variadic<Index>:$channelIds,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
Variadic<Index>:$sourceCoreIds,
|
||||||
DenseI32ArrayAttr:$targetCoreIds,
|
Variadic<Index>:$targetCoreIds,
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
|
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
|
||||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
Variadic<Index>:$channelIds,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
Variadic<Index>:$sourceCoreIds,
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
Variadic<Index>:$targetCoreIds
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -212,16 +275,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
|
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
|
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
|
||||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
Variadic<Index>:$channelIds,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
Variadic<Index>:$sourceCoreIds,
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
Variadic<Index>:$targetCoreIds
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -229,7 +294,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
|
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -240,7 +307,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
|||||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Attr:$weightIndex,
|
SpatTensor:$weight,
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -251,7 +318,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,7 +326,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
|||||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I32Attr:$weightIndex,
|
SpatTensor:$weight,
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -270,7 +337,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,208 @@
|
|||||||
|
#include <string>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
|
||||||
|
if (body.empty())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
Block& block = body.front();
|
||||||
|
if (argIdx >= block.getNumArguments())
|
||||||
|
return std::nullopt;
|
||||||
|
return block.getArgument(argIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
||||||
|
if (body.empty())
|
||||||
|
return std::nullopt;
|
||||||
|
return body.insertArgument(argIdx, type, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||||
|
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
|
||||||
|
return getBatchBodyArgument(getBody(), idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||||
|
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
unsigned weightCount = getWeights().size();
|
||||||
|
unsigned inputCount = getInputs().size();
|
||||||
|
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
|
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
|
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
|
||||||
|
if (!blockArg)
|
||||||
|
return std::nullopt;
|
||||||
|
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
unsigned weightCount = getWeights().size();
|
||||||
|
unsigned inputCount = getInputs().size();
|
||||||
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
|
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
|
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||||
|
if (!blockArg)
|
||||||
|
return std::nullopt;
|
||||||
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<std::tuple<OpResult, SpatCompute>>
|
||||||
|
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
if (idx > getNumResults())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(getOperation());
|
||||||
|
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||||
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
|
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
||||||
|
newCompute->setAttrs((*this)->getAttrs());
|
||||||
|
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||||
|
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||||
|
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||||
|
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||||
|
getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
|
rewriter.eraseOp(getOperation());
|
||||||
|
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
if (region.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||||
|
if (auto weightArg = getWeightArgument(index))
|
||||||
|
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||||
|
|
||||||
|
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||||
|
if (auto inputArg = getInputArgument(index))
|
||||||
|
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
||||||
|
return getBatchBodyArgument(getBody(), 1 + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||||
|
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||||
|
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
unsigned weightCount = getWeights().size();
|
||||||
|
unsigned inputCount = getInputs().size();
|
||||||
|
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
|
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
|
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||||
|
if (!blockArg)
|
||||||
|
return std::nullopt;
|
||||||
|
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
unsigned weightCount = getWeights().size();
|
||||||
|
unsigned inputCount = getInputs().size();
|
||||||
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
|
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
|
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||||
|
if (!blockArg)
|
||||||
|
return std::nullopt;
|
||||||
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
||||||
|
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
if (idx > getNumResults())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(getOperation());
|
||||||
|
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||||
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
|
auto newBatch =
|
||||||
|
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
||||||
|
newBatch->setAttrs((*this)->getAttrs());
|
||||||
|
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||||
|
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||||
|
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||||
|
if (newBatch.getBody().empty()) {
|
||||||
|
rewriter.eraseOp(newBatch);
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||||
|
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||||
|
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||||
|
getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
|
rewriter.eraseOp(getOperation());
|
||||||
|
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
if (region.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (auto laneArg = getLaneArgument())
|
||||||
|
setNameFn(*laneArg, "lane");
|
||||||
|
|
||||||
|
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||||
|
if (auto weightArg = getWeightArgument(index))
|
||||||
|
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||||
|
|
||||||
|
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||||
|
if (auto inputArg = getInputArgument(index))
|
||||||
|
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||||
|
|
||||||
|
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||||
|
auto outputArg = getOutputArgument(index);
|
||||||
|
if (!outputArg)
|
||||||
|
continue;
|
||||||
|
if (index == 0) {
|
||||||
|
setNameFn(*outputArg, "out");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||||
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
|
Region* bodyRegion = result.addRegion();
|
||||||
|
builder.createBlock(bodyRegion);
|
||||||
|
}
|
||||||
|
|
||||||
|
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||||
|
|
||||||
|
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||||
|
|
||||||
void SpatialDialect::initialize() {
|
void SpatialDialect::initialize() {
|
||||||
addTypes<
|
addTypes<
|
||||||
|
|||||||
@@ -5,10 +5,15 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/RegionKindInterface.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
/// Include the auto-generated header files containing the declarations
|
/// Include the auto-generated header files containing the declarations
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
|
||||||
|
|||||||
@@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
|
|||||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
|
||||||
ArrayRef<int64_t> channelIds,
|
|
||||||
ArrayRef<int32_t> sourceCoreIds,
|
|
||||||
ArrayRef<int32_t> targetCoreIds) {
|
|
||||||
printer << " channels ";
|
|
||||||
printCompressedIntegerList(printer, channelIds);
|
|
||||||
printer << " from ";
|
|
||||||
printCompressedIntegerList(printer, sourceCoreIds);
|
|
||||||
printer << " to ";
|
|
||||||
printCompressedIntegerList(printer, targetCoreIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
|
|
||||||
return parser.getBuilder().getDenseI64ArrayAttr(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
||||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||||
}
|
}
|
||||||
@@ -47,94 +31,86 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
|||||||
return parser.getBuilder().getI32IntegerAttr(value);
|
return parser.getBuilder().getI32IntegerAttr(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorSendOpTy>
|
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
printer << "(";
|
||||||
printer << " ";
|
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||||
printer.printOperand(op.getInput());
|
if (index != 0)
|
||||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
printer << ", ";
|
||||||
printer.printOptionalAttrDict(op->getAttrs(),
|
printer.printOperand(argument);
|
||||||
{op.getChannelIdsAttrName().getValue(),
|
}
|
||||||
op.getSourceCoreIdsAttrName().getValue(),
|
printer << ")";
|
||||||
op.getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(op.getInput().getType());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorReceiveOpTy>
|
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
if (parser.parseLParen())
|
||||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
|
||||||
printer.printOptionalAttrDict(op->getAttrs(),
|
|
||||||
{op.getChannelIdsAttrName().getValue(),
|
|
||||||
op.getSourceCoreIdsAttrName().getValue(),
|
|
||||||
op.getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(op.getOutput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
|
||||||
OpAsmParser::UnresolvedOperand input;
|
|
||||||
Type inputType;
|
|
||||||
SmallVector<int64_t> channelIds;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
|
|
||||||
if (parser.parseOperand(input))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalRParen()))
|
||||||
|
return success();
|
||||||
|
|
||||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
OpAsmParser::Argument argument;
|
||||||
if (hasMetadata) {
|
if (parser.parseArgument(argument))
|
||||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
|
||||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
|
||||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
arguments.push_back(argument);
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
if (parser.parseArgument(argument))
|
||||||
return failure();
|
return failure();
|
||||||
|
arguments.push_back(argument);
|
||||||
if (hasMetadata
|
}
|
||||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
return parser.parseRParen();
|
||||||
|| result.attributes.get("targetCoreIds")))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
|
||||||
if (hasMetadata) {
|
|
||||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return parser.resolveOperand(input, inputType, result.operands);
|
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
|
||||||
|
ArrayRef<Type> weightTypes,
|
||||||
|
ArrayRef<Type> outputTypes,
|
||||||
|
OpAsmParser::Argument& laneArg,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
|
||||||
|
Builder& builder) {
|
||||||
|
laneArg.type = builder.getIndexType();
|
||||||
|
regionArgs.push_back(laneArg);
|
||||||
|
applyArgumentTypes(weightTypes, weightArgs);
|
||||||
|
llvm::append_range(regionArgs, weightArgs);
|
||||||
|
applyArgumentTypes(inputTypes, inputArgs);
|
||||||
|
applyArgumentTypes(outputTypes, outputArgs);
|
||||||
|
llvm::append_range(regionArgs, inputArgs);
|
||||||
|
llvm::append_range(regionArgs, outputArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
static void
|
||||||
Type outputType;
|
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||||
SmallVector<int64_t> channelIds;
|
printCompressedValueList(printer, arguments, delimiter);
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
printer << " = ";
|
||||||
SmallVector<int32_t> targetCoreIds;
|
printCompressedValueList(printer, operands, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||||
if (hasMetadata) {
|
ListDelimiter delimiter,
|
||||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||||
|
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
return success();
|
||||||
if (hasMetadata
|
|
||||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
|
||||||
|| result.attributes.get("targetCoreIds")))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
|
||||||
if (hasMetadata) {
|
|
||||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result.addTypes(outputType);
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||||
|
switch (currentDelimiter) {
|
||||||
|
case ListDelimiter::Paren: return parser.parseRParen();
|
||||||
|
case ListDelimiter::Square: return parser.parseRSquare();
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported delimiter");
|
||||||
|
};
|
||||||
|
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||||
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,10 +218,27 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||||
|
SmallVector<Value> weightArgs;
|
||||||
|
weightArgs.reserve(getWeights().size());
|
||||||
|
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||||
|
auto weightArg = getWeightArgument(index);
|
||||||
|
if (!weightArg)
|
||||||
|
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||||
|
weightArgs.push_back(*weightArg);
|
||||||
|
}
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(getInputs().size());
|
||||||
|
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||||
|
auto inputArg = getInputArgument(index);
|
||||||
|
if (!inputArg)
|
||||||
|
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||||
|
inputArgs.push_back(*inputArg);
|
||||||
|
}
|
||||||
|
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
printer << " coreId " << coreIdAttr.getInt();
|
printer << " coreId " << coreIdAttr.getInt();
|
||||||
@@ -264,6 +257,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
@@ -272,10 +266,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
SmallVector<Type> outputTypes;
|
SmallVector<Type> outputTypes;
|
||||||
int32_t coreId = 0;
|
int32_t coreId = 0;
|
||||||
|
|
||||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||||
@@ -292,9 +287,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
|
|
||||||
if (weights.size() != weightTypes.size())
|
if (weights.size() != weightTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
if (inputs.size() != inputTypes.size())
|
if (inputs.size() != inputTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
if (regionArgs.size() != inputs.size())
|
if (inputArgs.size() != inputs.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
@@ -313,19 +310,58 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
result.addTypes(outputTypes);
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
Region* body = result.addRegion();
|
||||||
applyArgumentTypes(inputTypes, regionArgs);
|
applyArgumentTypes(weightTypes, weightArgs);
|
||||||
|
applyArgumentTypes(inputTypes, inputArgs);
|
||||||
|
llvm::append_range(regionArgs, weightArgs);
|
||||||
|
llvm::append_range(regionArgs, inputArgs);
|
||||||
return parser.parseRegion(*body, regionArgs);
|
return parser.parseRegion(*body, regionArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||||
printer << " lanes " << getLaneCount() << " ";
|
auto laneArg = getLaneArgument();
|
||||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
SmallVector<Value> weightArgs;
|
||||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
weightArgs.reserve(getWeights().size());
|
||||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
|
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||||
else
|
auto weightArg = getWeightArgument(index);
|
||||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
if (!weightArg)
|
||||||
|
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||||
|
weightArgs.push_back(*weightArg);
|
||||||
|
}
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(getInputs().size());
|
||||||
|
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||||
|
auto inputArg = getInputArgument(index);
|
||||||
|
if (!inputArg)
|
||||||
|
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||||
|
inputArgs.push_back(*inputArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<BlockArgument> outputArgs;
|
||||||
|
if (!laneArg)
|
||||||
|
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||||
|
if (getNumResults() != 0) {
|
||||||
|
outputArgs.reserve(getNumResults());
|
||||||
|
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||||
|
auto outputArg = getOutputArgument(index);
|
||||||
|
if (!outputArg)
|
||||||
|
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||||
|
outputArgs.push_back(*outputArg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
printer.printOperand(*laneArg);
|
||||||
|
printer << " = 0 to " << getLaneCount();
|
||||||
|
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||||
|
|
||||||
|
if (getNumResults() != 0) {
|
||||||
|
printer << " shared_outs";
|
||||||
|
printBlockArgumentList(printer, outputArgs);
|
||||||
|
}
|
||||||
|
|
||||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||||
printer << " coreIds ";
|
printer << " coreIds ";
|
||||||
@@ -337,9 +373,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
|||||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||||
|
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
|
||||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
|
|
||||||
else
|
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||||
@@ -350,7 +383,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
int64_t lowerBound = 0;
|
||||||
int32_t laneCount = 0;
|
int32_t laneCount = 0;
|
||||||
|
OpAsmParser::Argument laneArg;
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
@@ -359,13 +397,20 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
SmallVector<Type> outputTypes;
|
SmallVector<Type> outputTypes;
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
|
|
||||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||||
|
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||||
|
return failure();
|
||||||
|
if (lowerBound != 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||||
|
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||||
|
if (parseBlockArgumentList(parser, outputArgs))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||||
@@ -381,10 +426,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
|
|
||||||
if (weights.size() != weightTypes.size())
|
if (weights.size() != weightTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
if (inputs.size() != inputTypes.size())
|
if (inputs.size() != inputTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
if (regionArgs.size() != inputs.size())
|
if (inputArgs.size() != inputs.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (outputArgs.size() != outputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"number of shared output bindings and result types must match");
|
||||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
"coreIds cannot be specified both positionally and in attr-dict");
|
"coreIds cannot be specified both positionally and in attr-dict");
|
||||||
@@ -403,119 +453,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
result.addTypes(outputTypes);
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
Region* body = result.addRegion();
|
||||||
applyArgumentTypes(inputTypes, regionArgs);
|
applyBatchRegionArgumentTypes(
|
||||||
|
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||||
return parser.parseRegion(*body, regionArgs);
|
return parser.parseRegion(*body, regionArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||||
|
|
||||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
return parseTensorSendOp(parser, result);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printer.printOperand(getInput());
|
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
|
||||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||||
printer.printOptionalAttrDict(
|
|
||||||
(*this)->getAttrs(),
|
|
||||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getInput().getType());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
OpAsmParser::UnresolvedOperand input;
|
auto& builder = parser.getBuilder();
|
||||||
Type inputType;
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
if (parser.parseRegion(*region, regionArgs))
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
|
|
||||||
if (parser.parseOperand(input))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
if (region->empty())
|
||||||
if (hasMetadata) {
|
OpBuilder(builder.getContext()).createBlock(region.get());
|
||||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
result.addRegion(std::move(region));
|
||||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
return parser.parseOptionalAttrDict(result.attributes);
|
||||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (hasMetadata
|
|
||||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
|
||||||
|| result.attributes.get("targetCoreIds")))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
|
||||||
if (hasMetadata) {
|
|
||||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
}
|
|
||||||
|
|
||||||
return parser.resolveOperand(input, inputType, result.operands);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
|
||||||
|
|
||||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
return parseTensorSendOp(parser, result);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
|
||||||
|
|
||||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
return parseTensorReceiveOp(parser, result);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
|
||||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
|
||||||
printer.printOptionalAttrDict(
|
|
||||||
(*this)->getAttrs(),
|
|
||||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
|
||||||
printer << " : ";
|
|
||||||
printer.printType(getOutput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
Type outputType;
|
|
||||||
SmallVector<int64_t> channelIds;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
|
|
||||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
|
||||||
if (hasMetadata) {
|
|
||||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
|
||||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
|
||||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (hasMetadata
|
|
||||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
|
||||||
|| result.attributes.get("targetCoreIds")))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
|
||||||
if (hasMetadata) {
|
|
||||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
|
||||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
|
||||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
|
||||||
}
|
|
||||||
|
|
||||||
result.addTypes(outputType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
|
||||||
|
|
||||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
return parseTensorReceiveOp(parser, result);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
@@ -82,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||||
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
|
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
if (!shapedType)
|
||||||
|
|
||||||
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
|
|
||||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
|
||||||
|
|
||||||
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
|
|
||||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
|
||||||
return failure();
|
|
||||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
|
return shapedType.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||||
@@ -105,15 +99,98 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
|||||||
return batchOp.getLaneCount();
|
return batchOp.getLaneCount();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||||
Type type,
|
if (batchOp.getNumResults() == 0)
|
||||||
ArrayRef<int64_t> channelIds,
|
return false;
|
||||||
ArrayRef<int32_t> sourceCoreIds,
|
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||||
ArrayRef<int32_t> targetCoreIds,
|
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
unsigned argNumber = blockArg.getArgNumber();
|
||||||
|
auto firstOutputArg = batchOp.getOutputArgument(0);
|
||||||
|
if (!firstOutputArg)
|
||||||
|
return false;
|
||||||
|
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
|
||||||
|
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isConstantIndexLike(Value value) {
|
||||||
|
APInt constantValue;
|
||||||
|
return matchPattern(value, m_ConstantInt(&constantValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||||
|
if (value == laneArg || isConstantIndexLike(value))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
|
||||||
|
if (extractOp) {
|
||||||
|
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
|
||||||
|
auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
|
||||||
|
if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
|
||||||
|
return false;
|
||||||
|
return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||||
|
if (!addOp)
|
||||||
|
return false;
|
||||||
|
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|
||||||
|
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult
|
||||||
|
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||||
|
if (!sliceOp.hasUnitStride())
|
||||||
|
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||||
|
|
||||||
|
for (int64_t size : sliceOp.getStaticSizes())
|
||||||
|
if (ShapedType::isDynamic(size))
|
||||||
|
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||||
|
|
||||||
|
auto offsets = sliceOp.getOffsets();
|
||||||
|
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||||
|
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||||
|
if (!supported)
|
||||||
|
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
|
||||||
|
BlockArgument laneArg,
|
||||||
StringRef kind) {
|
StringRef kind) {
|
||||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
RankedTensorType sourceType = sliceOp.getSourceType();
|
||||||
|
RankedTensorType destType = sliceOp.getDestType();
|
||||||
|
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
|
||||||
|
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||||
|
if (!sliceOp.hasUnitStride())
|
||||||
|
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||||
|
|
||||||
|
for (int64_t size : sliceOp.getStaticSizes())
|
||||||
|
if (ShapedType::isDynamic(size))
|
||||||
|
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||||
|
|
||||||
|
auto offsets = sliceOp.getOffsets();
|
||||||
|
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||||
|
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||||
|
if (!supported)
|
||||||
|
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyTensorChannelSizes(
|
||||||
|
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||||
|
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||||
if (channelIds.empty())
|
if (channelCount == 0)
|
||||||
return op->emitError() << kind << " must carry at least one chunk";
|
return op->emitError() << kind << " must carry at least one chunk";
|
||||||
|
|
||||||
auto shapedType = dyn_cast<ShapedType>(type);
|
auto shapedType = dyn_cast<ShapedType>(type);
|
||||||
@@ -125,40 +202,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
|
|||||||
return op->emitError() << kind << " requires byte-sized elements";
|
return op->emitError() << kind << " requires byte-sized elements";
|
||||||
|
|
||||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
|
||||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBatchChannelSizes(Operation* op,
|
static LogicalResult
|
||||||
ArrayRef<int64_t> channelIds,
|
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
|
||||||
ArrayRef<int32_t> sourceCoreIds,
|
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||||
ArrayRef<int32_t> targetCoreIds) {
|
|
||||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
|
||||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||||
|
|
||||||
auto laneCount = getParentBatchLaneCount(op);
|
auto laneCount = getParentBatchLaneCount(op);
|
||||||
if (failed(laneCount))
|
if (failed(laneCount))
|
||||||
return op->emitError("must be nested inside spat.compute_batch");
|
return op->emitError("must be nested inside spat.compute_batch");
|
||||||
if (channelIds.size() != static_cast<size_t>(*laneCount))
|
if (channelCount != static_cast<size_t>(*laneCount))
|
||||||
return op->emitError("channel metadata length must match parent laneCount");
|
return op->emitError("channel metadata length must match parent laneCount");
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
static LogicalResult verifyTensorBatchChannelSizes(
|
||||||
Type type,
|
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||||
ArrayRef<int64_t> channelIds,
|
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||||
ArrayRef<int32_t> sourceCoreIds,
|
|
||||||
ArrayRef<int32_t> targetCoreIds,
|
|
||||||
StringRef kind) {
|
|
||||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
|
||||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||||
|
|
||||||
auto laneCount = getParentBatchLaneCount(op);
|
auto laneCount = getParentBatchLaneCount(op);
|
||||||
if (failed(laneCount))
|
if (failed(laneCount))
|
||||||
return op->emitError("must be nested inside spat.compute_batch");
|
return op->emitError("must be nested inside spat.compute_batch");
|
||||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
|
||||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||||
|
|
||||||
auto shapedType = dyn_cast<ShapedType>(type);
|
auto shapedType = dyn_cast<ShapedType>(type);
|
||||||
@@ -169,7 +240,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
|||||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||||
return op->emitError() << kind << " requires byte-sized elements";
|
return op->emitError() << kind << " requires byte-sized elements";
|
||||||
|
|
||||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
|
||||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||||
if (totalBytes % chunkCount != 0)
|
if (totalBytes % chunkCount != 0)
|
||||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||||
@@ -177,28 +248,61 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
static Region* getParentRegion(Value value) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
if (!yieldOp)
|
return blockArg.getOwner()->getParent();
|
||||||
return op->emitError("body must terminate with spat.yield");
|
if (Operation* definingOp = value.getDefiningOp())
|
||||||
if (outputTypes.empty()) {
|
return definingOp->getParentRegion();
|
||||||
if (yieldOp.getNumOperands() != 0)
|
return nullptr;
|
||||||
return op->emitError("body yield must be empty when compute_batch has no results");
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (yieldOp.getNumOperands() != 1)
|
|
||||||
return op->emitError("body yield must produce exactly one value");
|
|
||||||
if (yieldOp.getOperand(0).getType() != outputTypes[0])
|
|
||||||
return op->emitError("body yield type must match output type");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||||
|
Region* parentRegion = getParentRegion(value);
|
||||||
|
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isConstantExternalValue(Value value) {
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
region.walk([&](Operation* op) {
|
||||||
|
for (OpOperand& operand : op->getOpOperands()) {
|
||||||
|
Value value = operand.get();
|
||||||
|
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||||
|
<< kind << " body may only directly reference external constants";
|
||||||
|
diagnostic.attachNote(op->getLoc())
|
||||||
|
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||||
|
if (batchOp.getNumResults() == 0) {
|
||||||
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp)
|
||||||
|
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
|
||||||
|
if (yieldOp.getNumOperands() != 0)
|
||||||
|
return batchOp.emitError("resultless compute_batch body yield must be empty");
|
||||||
|
}
|
||||||
|
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
|
||||||
|
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto laneArg = batchOp.getLaneArgument();
|
||||||
|
if (!laneArg)
|
||||||
|
return batchOp.emitError("compute_batch body must have a lane block argument");
|
||||||
for (auto& bodyOp : block) {
|
for (auto& bodyOp : block) {
|
||||||
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
|
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
|
||||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
|
||||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
return failure();
|
||||||
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
|
|
||||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
|
||||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -206,9 +310,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult SpatMVMOp::verify() {
|
LogicalResult SpatMVMOp::verify() {
|
||||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
return emitError("weight must be a shaped value");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getInput().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
@@ -221,9 +325,9 @@ LogicalResult SpatMVMOp::verify() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatVMMOp::verify() {
|
LogicalResult SpatVMMOp::verify() {
|
||||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
return emitError("weight must be a shaped value");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getInput().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
@@ -354,6 +458,21 @@ LogicalResult verifyComputeResultsUses(Operation* op) {
|
|||||||
|
|
||||||
LogicalResult SpatCompute::verify() {
|
LogicalResult SpatCompute::verify() {
|
||||||
auto& block = getBody().front();
|
auto& block = getBody().front();
|
||||||
|
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
||||||
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
|
return emitError("compute body must have weight and input block arguments");
|
||||||
|
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||||
|
auto blockArg = getWeightArgument(weightIndex);
|
||||||
|
if (!blockArg || blockArg->getType() != weight.getType())
|
||||||
|
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||||
|
}
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||||
|
auto blockArg = getInputArgument(inputIndex);
|
||||||
|
if (!blockArg || blockArg->getType() != input.getType())
|
||||||
|
return emitError("compute input block argument types must match input operand types exactly");
|
||||||
|
}
|
||||||
|
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
@@ -386,9 +505,11 @@ LogicalResult SpatCompute::verify() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto arg : block.getArguments())
|
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||||
if (arg.use_empty())
|
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||||
return emitError("ComputeOp block argument is not used");
|
return emitError("ComputeOp block argument is not used");
|
||||||
|
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||||
|
return failure();
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
@@ -397,44 +518,46 @@ LogicalResult SpatCompute::verify() {
|
|||||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||||
return verifyTensorChannelSizes(getOperation(),
|
return verifyTensorChannelSizes(getOperation(),
|
||||||
getInput().getType(),
|
getInput().getType(),
|
||||||
getChannelIds(),
|
getChannelIds().size(),
|
||||||
getSourceCoreIds(),
|
getSourceCoreIds().size(),
|
||||||
getTargetCoreIds(),
|
getTargetCoreIds().size(),
|
||||||
"channel_send_tensor");
|
"channel_send_tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||||
return verifyTensorChannelSizes(getOperation(),
|
return verifyTensorChannelSizes(getOperation(),
|
||||||
getOutput().getType(),
|
getOutput().getType(),
|
||||||
getChannelIds(),
|
getChannelIds().size(),
|
||||||
getSourceCoreIds(),
|
getSourceCoreIds().size(),
|
||||||
getTargetCoreIds(),
|
getTargetCoreIds().size(),
|
||||||
"channel_receive_tensor");
|
"channel_receive_tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
return verifyBatchChannelSizes(
|
||||||
|
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||||
return verifyTensorBatchChannelSizes(getOperation(),
|
return verifyTensorBatchChannelSizes(getOperation(),
|
||||||
getInput().getType(),
|
getInput().getType(),
|
||||||
getChannelIds(),
|
getChannelIds().size(),
|
||||||
getSourceCoreIds(),
|
getSourceCoreIds().size(),
|
||||||
getTargetCoreIds(),
|
getTargetCoreIds().size(),
|
||||||
"channel_send_tensor_batch");
|
"channel_send_tensor_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
return verifyBatchChannelSizes(
|
||||||
|
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||||
return verifyTensorBatchChannelSizes(getOperation(),
|
return verifyTensorBatchChannelSizes(getOperation(),
|
||||||
getOutput().getType(),
|
getOutput().getType(),
|
||||||
getChannelIds(),
|
getChannelIds().size(),
|
||||||
getSourceCoreIds(),
|
getSourceCoreIds().size(),
|
||||||
getTargetCoreIds(),
|
getTargetCoreIds().size(),
|
||||||
"channel_receive_tensor_batch");
|
"channel_receive_tensor_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -444,35 +567,6 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
return emitError("laneCount must be positive");
|
return emitError("laneCount must be positive");
|
||||||
|
|
||||||
auto laneCountSz = static_cast<size_t>(count);
|
auto laneCountSz = static_cast<size_t>(count);
|
||||||
if (getWeights().size() % laneCountSz != 0)
|
|
||||||
return emitError("number of weights must be a multiple of laneCount");
|
|
||||||
|
|
||||||
if (!getInputs().empty() && getInputs().size() != laneCountSz)
|
|
||||||
return emitError("number of inputs must be either 0 or laneCount");
|
|
||||||
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
|
|
||||||
return emitError("number of outputs must be either 0 or laneCount");
|
|
||||||
|
|
||||||
size_t weightsPerLane = getWeights().size() / laneCountSz;
|
|
||||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
|
|
||||||
Type weightType = getWeights()[weightIndex].getType();
|
|
||||||
for (size_t lane = 1; lane < laneCountSz; ++lane)
|
|
||||||
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
|
|
||||||
return emitError("corresponding weights across lanes must have the same type");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!getInputs().empty()) {
|
|
||||||
Type inputType = getInputs()[0].getType();
|
|
||||||
for (Value in : getInputs().drop_front())
|
|
||||||
if (in.getType() != inputType)
|
|
||||||
return emitError("all inputs must have the same type");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!getOutputs().empty()) {
|
|
||||||
Type outputType = getOutputs()[0].getType();
|
|
||||||
for (Value out : getOutputs().drop_front())
|
|
||||||
if (out.getType() != outputType)
|
|
||||||
return emitError("all outputs must have the same type");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
||||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||||
@@ -482,27 +576,70 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
return emitError("compute_batch coreIds array length must match laneCount");
|
return emitError("compute_batch coreIds array length must match laneCount");
|
||||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||||
return emitError("compute_batch coreIds values must be non-negative");
|
return emitError("compute_batch coreIds values must be non-negative");
|
||||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
DenseSet<int32_t> seenCoreIds;
|
||||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||||
if (!seenCoreIds.insert(coreId).second)
|
if (!seenCoreIds.insert(coreId).second)
|
||||||
return emitError("compute_batch coreIds values must be distinct");
|
return emitError("compute_batch coreIds values must be unique");
|
||||||
}
|
}
|
||||||
|
|
||||||
Block& block = getBody().front();
|
Block& block = getBody().front();
|
||||||
if (getInputs().empty()) {
|
if (block.getNumArguments() == 0)
|
||||||
if (block.getNumArguments() != 0)
|
return emitError("compute_batch body must have exactly one lane block argument");
|
||||||
return emitError("compute_batch body must have no block arguments when there are no inputs");
|
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||||
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
|
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||||
|
auto laneArg = getLaneArgument();
|
||||||
|
if (!laneArg || !laneArg->getType().isIndex())
|
||||||
|
return emitError("compute_batch first block argument must have index type");
|
||||||
|
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||||
|
auto blockArg = getWeightArgument(weightIndex);
|
||||||
|
if (!blockArg || blockArg->getType() != weight.getType())
|
||||||
|
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||||
}
|
}
|
||||||
else {
|
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||||
if (block.getNumArguments() != 1)
|
auto blockArg = getInputArgument(inputIndex);
|
||||||
return emitError("compute_batch body must have exactly one block argument");
|
if (!blockArg || blockArg->getType() != input.getType())
|
||||||
if (block.getArgument(0).getType() != getInputs()[0].getType())
|
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||||
return emitError("body block argument type must match input type");
|
}
|
||||||
|
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||||
|
auto blockArg = getOutputArgument(resultIndex);
|
||||||
|
if (!blockArg || blockArg->getType() != resultType)
|
||||||
|
return emitError("compute_batch output block argument types must match result types exactly");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||||
|
return failure();
|
||||||
|
return verifyBatchBody(*this, block);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatInParallelOp::verify() {
|
||||||
|
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
||||||
|
if (!batchOp)
|
||||||
|
return emitOpError("expected spat.compute_batch parent");
|
||||||
|
if (batchOp.getNumResults() == 0)
|
||||||
|
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||||
|
|
||||||
|
auto laneArg = batchOp.getLaneArgument();
|
||||||
|
if (!laneArg)
|
||||||
|
return emitOpError("expected compute_batch lane block argument");
|
||||||
|
for (Operation& op : getRegion().front().getOperations()) {
|
||||||
|
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||||
|
if (!insertSliceOp)
|
||||||
|
return emitOpError("expected only tensor.parallel_insert_slice ops");
|
||||||
|
|
||||||
|
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||||
|
for (OpOperand& destination : destinations)
|
||||||
|
if (!isBatchOutputArgument(batchOp, destination.get()))
|
||||||
|
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "DCPAnalysis.hpp"
|
|
||||||
#include "../Scheduling/ComputeGraph.hpp"
|
#include "../Scheduling/ComputeGraph.hpp"
|
||||||
#include "../Scheduling/DcpScheduler.hpp"
|
#include "../Scheduling/DcpScheduler.hpp"
|
||||||
|
#include "DCPAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
+1372
-540
File diff suppressed because it is too large
Load Diff
@@ -10,8 +10,7 @@ namespace spatial {
|
|||||||
|
|
||||||
class MergeScheduleMaterializer {
|
class MergeScheduleMaterializer {
|
||||||
public:
|
public:
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
|
||||||
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
|
|||||||
class ScopedMergePhaseTimer {
|
class ScopedMergePhaseTimer {
|
||||||
public:
|
public:
|
||||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||||
: enabled(isMergeProfilingEnabled()),
|
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||||
phase(phaseName.str()) {
|
|
||||||
if (enabled)
|
if (enabled)
|
||||||
start = std::chrono::steady_clock::now();
|
start = std::chrono::steady_clock::now();
|
||||||
}
|
}
|
||||||
@@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
|||||||
|
|
||||||
MergeIrCounts counts = collectMergeIrCounts(funcOp);
|
MergeIrCounts counts = collectMergeIrCounts(funcOp);
|
||||||
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
||||||
<< " compute=" << counts.topLevelComputeCount
|
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
||||||
<< " compute_batch=" << counts.topLevelComputeBatchCount
|
|
||||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
<< " scalar_send=" << counts.scalarChannelSendCount
|
||||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
||||||
<< " tensor_send=" << counts.tensorChannelSendCount
|
<< " tensor_send=" << counts.tensorChannelSendCount
|
||||||
<< " tensor_recv=" << counts.tensorChannelReceiveCount
|
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
||||||
<< " wvmm=" << counts.wvmmCount
|
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
||||||
<< " vadd=" << counts.vaddCount
|
|
||||||
<< " scf_for=" << counts.scfForCount << "\n";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
@@ -167,21 +163,21 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
|||||||
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) {
|
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
|
||||||
|
ValueRange sourceWeights) {
|
||||||
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights()))
|
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
|
||||||
targetWeightIndices[weight].push_back(weightIndex);
|
targetWeightIndices[weight].push_back(weightIndex);
|
||||||
|
|
||||||
DenseMap<Value, size_t> usedSourceWeightOccurrences;
|
DenseMap<Value, size_t> usedSourceWeightOccurrences;
|
||||||
SmallVector<size_t> sourceToTargetIndex;
|
SmallVector<size_t> sourceToTargetIndex;
|
||||||
sourceToTargetIndex.reserve(sourceWeights.size());
|
sourceToTargetIndex.reserve(sourceWeights.size());
|
||||||
auto targetWeights = target.getWeightsMutable();
|
|
||||||
for (Value weight : sourceWeights) {
|
for (Value weight : sourceWeights) {
|
||||||
size_t occurrence = usedSourceWeightOccurrences[weight]++;
|
size_t occurrence = usedSourceWeightOccurrences[weight]++;
|
||||||
auto& matchingIndices = targetWeightIndices[weight];
|
auto& matchingIndices = targetWeightIndices[weight];
|
||||||
if (occurrence >= matchingIndices.size()) {
|
if (occurrence >= matchingIndices.size()) {
|
||||||
size_t newIndex = target.getWeights().size();
|
size_t newIndex = targetWeights.size();
|
||||||
targetWeights.append(weight);
|
targetWeights.push_back(weight);
|
||||||
matchingIndices.push_back(newIndex);
|
matchingIndices.push_back(newIndex);
|
||||||
sourceToTargetIndex.push_back(newIndex);
|
sourceToTargetIndex.push_back(newIndex);
|
||||||
continue;
|
continue;
|
||||||
@@ -213,37 +209,50 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
auto& computeUse = *compute->getUses().begin();
|
auto& computeUse = *compute->getUses().begin();
|
||||||
auto child = cast<SpatCompute>(computeUse.getOwner());
|
auto child = cast<SpatCompute>(computeUse.getOwner());
|
||||||
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||||
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
auto childInputIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
SmallVector<Value> mergedWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(mergedWeights, child.getWeights());
|
||||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
SmallVector<Value> mergedInputs(compute.getInputs().begin(), compute.getInputs().end());
|
||||||
|
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), mergedWeights, mergedInputs);
|
||||||
|
Block* newBody = rewriter.createBlock(&newCompute.getBodyRegion());
|
||||||
|
for (Value weight : mergedWeights)
|
||||||
|
newBody->addArgument(weight.getType(), loc);
|
||||||
|
for (Value input : mergedInputs)
|
||||||
|
newBody->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights());
|
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) {
|
||||||
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
|
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||||
mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex]));
|
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||||
|
assert(oldWeightArg && newWeightArg && "expected compute weight block arguments");
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
|
||||||
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
|
||||||
newTerminator->erase();
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
|
||||||
auto remapWeightIndex = [&](auto weightedOp) {
|
|
||||||
auto oldIndex = weightedOp.getWeightIndex();
|
|
||||||
assert(static_cast<size_t>(oldIndex) < childWeightToNewIndex.size() && "weight index out of range");
|
|
||||||
weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]);
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto& op : child.getBody().front()) {
|
|
||||||
auto newInst = rewriter.clone(op, mapper);
|
|
||||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
|
|
||||||
remapWeightIndex(weightedMvmOp);
|
|
||||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
|
|
||||||
remapWeightIndex(weightedVmmOp);
|
|
||||||
}
|
}
|
||||||
|
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
auto oldInputArg = compute.getInputArgument(inputIndex);
|
||||||
|
auto newInputArg = newCompute.getInputArgument(inputIndex);
|
||||||
|
assert(oldInputArg && newInputArg && "expected compute input block arguments");
|
||||||
|
mapper.map(*oldInputArg, *newInputArg);
|
||||||
|
}
|
||||||
|
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) {
|
||||||
|
auto oldWeightArg = child.getWeightArgument(oldIndex);
|
||||||
|
auto newWeightArg = newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]);
|
||||||
|
assert(oldWeightArg && newWeightArg && "expected child compute weight block arguments");
|
||||||
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBody);
|
||||||
|
auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
|
||||||
|
for (Operation& op : compute.getBody().front().without_terminator())
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
auto childInputArg = child.getInputArgument(childInputIndex);
|
||||||
|
assert(childInputArg && "expected child compute input block argument");
|
||||||
|
mapper.map(*childInputArg, mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToEnd(newBody);
|
||||||
|
for (auto& op : child.getBody().front())
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
child.replaceAllUsesWith(newCompute);
|
child.replaceAllUsesWith(newCompute);
|
||||||
toErase.insert(child);
|
toErase.insert(child);
|
||||||
@@ -651,12 +660,12 @@ public:
|
|||||||
|
|
||||||
emitMergeIrCounts("after-materialization", func);
|
emitMergeIrCounts("after-materialization", func);
|
||||||
|
|
||||||
if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
|
/*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
emitMergeIrCounts("after-post-merge-compaction", func);
|
emitMergeIrCounts("after-post-merge-compaction", func);*/
|
||||||
|
|
||||||
{
|
{
|
||||||
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
@@ -61,6 +62,66 @@ std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
|||||||
|
|
||||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||||
|
|
||||||
|
static FailureOr<int64_t> getConstantI64Value(Value value) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
return constantValue.getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
|
||||||
|
uint64_t& channelId,
|
||||||
|
uint32_t& sourceCoreId,
|
||||||
|
uint32_t& targetCoreId) {
|
||||||
|
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||||
|
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||||
|
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||||
|
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||||
|
return false;
|
||||||
|
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||||
|
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||||
|
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
|
||||||
|
uint64_t& channelId,
|
||||||
|
uint32_t& sourceCoreId,
|
||||||
|
uint32_t& targetCoreId) {
|
||||||
|
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||||
|
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||||
|
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||||
|
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||||
|
return false;
|
||||||
|
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||||
|
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||||
|
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||||
|
SmallVector<Value> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
|
||||||
|
SmallVector<Value> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (int32_t value : values)
|
||||||
|
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||||
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||||
return static_cast<uint64_t>(phaseAttr.getInt());
|
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||||
@@ -208,6 +269,7 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
|||||||
|
|
||||||
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||||
DenseSet<Operation*> consumed;
|
DenseSet<Operation*> consumed;
|
||||||
DenseMap<Operation*, size_t> computeOrder;
|
DenseMap<Operation*, size_t> computeOrder;
|
||||||
@@ -316,8 +378,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
|||||||
entries.reserve(group.size());
|
entries.reserve(group.size());
|
||||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||||
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
||||||
entries.push_back(
|
BatchReceiveEntry entry;
|
||||||
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
||||||
|
return;
|
||||||
|
entries.push_back(entry);
|
||||||
++opIts[groupIndex];
|
++opIts[groupIndex];
|
||||||
}
|
}
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
@@ -331,12 +395,15 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
|||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
}
|
}
|
||||||
|
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
|
||||||
|
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
|
||||||
|
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
|
||||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||||
receiveOp.getLoc(),
|
receiveOp.getLoc(),
|
||||||
receiveOp.getOutput().getType(),
|
receiveOp.getOutput().getType(),
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
channelIdValues,
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
sourceCoreIdValues,
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
targetCoreIdValues);
|
||||||
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -351,7 +418,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
|||||||
entries.reserve(group.size());
|
entries.reserve(group.size());
|
||||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||||
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
||||||
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
BatchSendEntry entry;
|
||||||
|
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
||||||
|
return;
|
||||||
|
entries.push_back(entry);
|
||||||
++opIts[groupIndex];
|
++opIts[groupIndex];
|
||||||
}
|
}
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
@@ -365,11 +435,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
|||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
}
|
}
|
||||||
|
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
|
||||||
|
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
|
||||||
|
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
|
||||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||||
sendOp.getLoc(),
|
sendOp.getLoc(),
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
channelIdValues,
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
sourceCoreIdValues,
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
targetCoreIdValues,
|
||||||
mapper.lookup(sendOp.getInput()));
|
mapper.lookup(sendOp.getInput()));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
@@ -30,7 +31,7 @@ enum class RegularStepKind {
|
|||||||
|
|
||||||
struct RegularStep {
|
struct RegularStep {
|
||||||
RegularStepKind kind;
|
RegularStepKind kind;
|
||||||
int32_t weightIndex = 0;
|
Value weight;
|
||||||
Value invariantOperand;
|
Value invariantOperand;
|
||||||
Type resultType;
|
Type resultType;
|
||||||
};
|
};
|
||||||
@@ -73,15 +74,90 @@ static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
|
|||||||
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
|
static FailureOr<int64_t> getConstantI64Value(Value value) {
|
||||||
SmallVectorImpl<int32_t>& sourceCoreIds,
|
APInt constantValue;
|
||||||
SmallVectorImpl<int32_t>& targetCoreIds,
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
uint64_t channelId,
|
return failure();
|
||||||
uint32_t sourceCoreId,
|
return constantValue.getSExtValue();
|
||||||
uint32_t targetCoreId) {
|
}
|
||||||
channelIds.push_back(static_cast<int64_t>(channelId));
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
|
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
|
APInt constantValue;
|
||||||
|
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||||
|
return failure();
|
||||||
|
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
|
||||||
|
uint64_t& channelId,
|
||||||
|
uint32_t& sourceCoreId,
|
||||||
|
uint32_t& targetCoreId) {
|
||||||
|
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||||
|
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||||
|
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||||
|
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||||
|
return false;
|
||||||
|
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||||
|
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||||
|
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
|
||||||
|
uint64_t& channelId,
|
||||||
|
uint32_t& sourceCoreId,
|
||||||
|
uint32_t& targetCoreId) {
|
||||||
|
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||||
|
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||||
|
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||||
|
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||||
|
return false;
|
||||||
|
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||||
|
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||||
|
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||||
|
SmallVector<Value> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
|
||||||
|
SmallVector<Value> constants;
|
||||||
|
constants.reserve(values.size());
|
||||||
|
for (int32_t value : values)
|
||||||
|
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||||
|
return constants;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<Operation*> getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) {
|
||||||
|
SmallVector<Operation*> defs;
|
||||||
|
defs.reserve(metadataOperandCount);
|
||||||
|
for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) {
|
||||||
|
Operation* def = channelOp->getOperand(operandIndex).getDefiningOp();
|
||||||
|
auto constantOp = dyn_cast_or_null<arith::ConstantOp>(def);
|
||||||
|
if (!constantOp || def->getBlock() != channelOp->getBlock())
|
||||||
|
continue;
|
||||||
|
defs.push_back(def);
|
||||||
|
}
|
||||||
|
llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); });
|
||||||
|
return defs;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) {
|
||||||
|
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
|
||||||
|
metadataDef->moveBefore(insertionPoint);
|
||||||
|
channelOp->moveBefore(insertionPoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) {
|
||||||
|
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
|
||||||
|
metadataDef->moveBefore(block, insertionPoint);
|
||||||
|
channelOp->moveBefore(block, insertionPoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||||
@@ -196,7 +272,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||||
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand
|
||||||
&& lhs.resultType == rhs.resultType;
|
&& lhs.resultType == rhs.resultType;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,8 +303,7 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
|||||||
chunk.input = startOp.getInput();
|
chunk.input = startOp.getInput();
|
||||||
chunk.output = startOp.getOutput();
|
chunk.output = startOp.getOutput();
|
||||||
chunk.ops.push_back(startOp.getOperation());
|
chunk.ops.push_back(startOp.getOperation());
|
||||||
chunk.steps.push_back(
|
chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()});
|
||||||
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
|
|
||||||
|
|
||||||
Value currentValue = startOp.getOutput();
|
Value currentValue = startOp.getOutput();
|
||||||
while (currentValue.hasOneUse()) {
|
while (currentValue.hasOneUse()) {
|
||||||
@@ -241,9 +316,9 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
|||||||
break;
|
break;
|
||||||
|
|
||||||
if (vaddOp.getLhs() == currentValue)
|
if (vaddOp.getLhs() == currentValue)
|
||||||
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||||
else if (vaddOp.getRhs() == currentValue)
|
else if (vaddOp.getRhs() == currentValue)
|
||||||
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||||
else
|
else
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@@ -255,7 +330,8 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
|||||||
return chunk;
|
return chunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
static RegularCompactionResult
|
||||||
|
compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run, OperationFolder& constantFolder) {
|
||||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||||
const RegularChunk& anchorChunk = run.front();
|
const RegularChunk& anchorChunk = run.front();
|
||||||
RegularCompactionResult result;
|
RegularCompactionResult result;
|
||||||
@@ -275,9 +351,9 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
|
|||||||
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
|
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
|
||||||
auto packedInit = tensor::EmptyOp::create(
|
auto packedInit = tensor::EmptyOp::create(
|
||||||
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
|
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
|
||||||
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
|
auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder);
|
||||||
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
|
auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast<int64_t>(run.size()), constantFolder);
|
||||||
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
|
auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder);
|
||||||
auto loop =
|
auto loop =
|
||||||
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||||
|
|
||||||
@@ -290,8 +366,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
|
|||||||
|
|
||||||
Value inputRowOffset = iv;
|
Value inputRowOffset = iv;
|
||||||
if (inputType.getDimSize(0) != 1) {
|
if (inputType.getDimSize(0) != 1) {
|
||||||
auto rowsPerValue =
|
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder);
|
||||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
|
|
||||||
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,8 +395,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
|
|||||||
Value mappedOutput = mapping.lookup(anchorChunk.output);
|
Value mappedOutput = mapping.lookup(anchorChunk.output);
|
||||||
Value outputRowOffset = iv;
|
Value outputRowOffset = iv;
|
||||||
if (outputType.getDimSize(0) != 1) {
|
if (outputType.getDimSize(0) != 1) {
|
||||||
auto rowsPerValue =
|
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder);
|
||||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
|
|
||||||
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,35 +463,50 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
|||||||
Block& block = compute.getBody().front();
|
Block& block = compute.getBody().front();
|
||||||
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||||
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
||||||
|
Operation* firstForwardedSend = nullptr;
|
||||||
|
|
||||||
for (Operation& op : block) {
|
for (Operation& op : block) {
|
||||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
||||||
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId)
|
uint64_t channelId = 0;
|
||||||
&& isForwardedChannelPayload(sendOp.getInput(), block)) {
|
uint32_t sourceCoreId = 0;
|
||||||
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId());
|
uint32_t targetCoreId = 0;
|
||||||
|
if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId)
|
||||||
|
&& sourceCoreId == static_cast<uint32_t>(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||||
|
if (!firstForwardedSend)
|
||||||
|
firstForwardedSend = sendOp.getOperation();
|
||||||
|
uint64_t key = getEndpointKey(sourceCoreId, targetCoreId);
|
||||||
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||||
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
uint64_t channelId = 0;
|
||||||
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|
||||||
|
|| targetCoreId != static_cast<uint32_t>(coreId) || sourceCoreId >= static_cast<uint32_t>(coreId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId());
|
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), sourceCoreId);
|
||||||
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
||||||
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
||||||
moves.push_back({receiveOp, firstMatchingSend->second});
|
moves.push_back({receiveOp, firstMatchingSend->second});
|
||||||
|
else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp))
|
||||||
|
moves.push_back({receiveOp, firstForwardedSend});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [receiveOp, insertionPoint] : moves)
|
for (auto [receiveOp, insertionPoint] : moves)
|
||||||
receiveOp->moveBefore(insertionPoint);
|
moveScalarChannelBundleBefore(receiveOp, insertionPoint);
|
||||||
|
|
||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||||
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|
||||||
|
|| sourceCoreId >= static_cast<uint32_t>(coreId)) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -425,18 +514,32 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
|||||||
Type outputType = receiveOp.getOutput().getType();
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||||
|
uint64_t currentChannelId = 0;
|
||||||
|
uint32_t currentSourceCoreId = 0;
|
||||||
|
uint32_t currentTargetCoreId = 0;
|
||||||
return current.getOutput().getType() == outputType
|
return current.getOutput().getType() == outputType
|
||||||
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId);
|
&& getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId)
|
||||||
|
&& currentSourceCoreId < static_cast<uint32_t>(coreId);
|
||||||
});
|
});
|
||||||
|
|
||||||
if (run.ops.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
||||||
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||||
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
uint64_t lhsChannelId = 0;
|
||||||
|
uint32_t lhsSourceCoreId = 0;
|
||||||
|
uint32_t lhsTargetCoreId = 0;
|
||||||
|
uint64_t rhsChannelId = 0;
|
||||||
|
uint32_t rhsSourceCoreId = 0;
|
||||||
|
uint32_t rhsTargetCoreId = 0;
|
||||||
|
bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId);
|
||||||
|
bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId);
|
||||||
|
if (!lhsHasMetadata || !rhsHasMetadata)
|
||||||
|
return false;
|
||||||
|
return lhsSourceCoreId > rhsSourceCoreId;
|
||||||
});
|
});
|
||||||
Block::iterator insertIt = run.end;
|
Block::iterator insertIt = run.end;
|
||||||
for (auto op : sorted)
|
for (auto op : sorted)
|
||||||
op->moveBefore(&block, insertIt);
|
moveScalarChannelBundleBefore(op, &block, insertIt);
|
||||||
}
|
}
|
||||||
|
|
||||||
it = run.end;
|
it = run.end;
|
||||||
@@ -446,6 +549,7 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
Block& block = compute.getBody().front();
|
Block& block = compute.getBody().front();
|
||||||
@@ -461,7 +565,14 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
bool hasRepeatedEndpoint = false;
|
bool hasRepeatedEndpoint = false;
|
||||||
DenseSet<uint64_t> seenEndpoints;
|
DenseSet<uint64_t> seenEndpoints;
|
||||||
for (auto op : run.ops) {
|
for (auto op : run.ops) {
|
||||||
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||||
|
hasRepeatedEndpoint = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId);
|
||||||
if (!seenEndpoints.insert(endpointKey).second) {
|
if (!seenEndpoints.insert(endpointKey).second) {
|
||||||
hasRepeatedEndpoint = true;
|
hasRepeatedEndpoint = true;
|
||||||
break;
|
break;
|
||||||
@@ -478,8 +589,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
};
|
};
|
||||||
SmallVector<ReceiveEntry> sortedEntries;
|
SmallVector<ReceiveEntry> sortedEntries;
|
||||||
sortedEntries.reserve(run.ops.size());
|
sortedEntries.reserve(run.ops.size());
|
||||||
for (auto [originalIndex, op] : llvm::enumerate(run.ops))
|
for (auto [originalIndex, op] : llvm::enumerate(run.ops)) {
|
||||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||||
|
sortedEntries.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId});
|
||||||
|
}
|
||||||
|
if (sortedEntries.empty()) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
@@ -488,8 +611,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sourceCoreIds.reserve(sortedEntries.size());
|
sourceCoreIds.reserve(sortedEntries.size());
|
||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
for (ReceiveEntry& entry : sortedEntries) {
|
for (ReceiveEntry& entry : sortedEntries) {
|
||||||
appendChannelAttrs(
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||||
@@ -506,13 +630,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
: RankedTensorType {};
|
: RankedTensorType {};
|
||||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.ops.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto compactReceive =
|
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||||
run.ops.front().getLoc(),
|
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||||
packedType,
|
auto compactReceive = spatial::SpatChannelReceiveTensorOp::create(
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues);
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
|
||||||
if (concatOp && concatPackedType) {
|
if (concatOp && concatPackedType) {
|
||||||
replaceConcatRunWithPackedValue(concatOp,
|
replaceConcatRunWithPackedValue(concatOp,
|
||||||
concatStartIndex,
|
concatStartIndex,
|
||||||
@@ -551,8 +673,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
};
|
};
|
||||||
SmallVector<SendEntry> sortedEntries;
|
SmallVector<SendEntry> sortedEntries;
|
||||||
sortedEntries.reserve(run.ops.size());
|
sortedEntries.reserve(run.ops.size());
|
||||||
for (auto op : run.ops)
|
for (auto op : run.ops) {
|
||||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||||
|
sortedEntries.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId});
|
||||||
|
}
|
||||||
|
if (sortedEntries.empty()) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
@@ -563,20 +697,22 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
inputs.reserve(sortedEntries.size());
|
inputs.reserve(sortedEntries.size());
|
||||||
for (SendEntry& entry : sortedEntries) {
|
for (SendEntry& entry : sortedEntries) {
|
||||||
appendChannelAttrs(
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
inputs.push_back(entry.op.getInput());
|
inputs.push_back(entry.op.getInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.ops.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||||
run.ops.front().getLoc(),
|
SmallVector<Value> sourceCoreIdValues =
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
SmallVector<Value> targetCoreIdValues =
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||||
packedInput);
|
spatial::SpatChannelSendTensorOp::create(
|
||||||
|
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
|
||||||
for (auto op : run.ops)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
@@ -606,9 +742,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (run.ops.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<Value> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<Value> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<Value> targetCoreIds;
|
||||||
for (auto op : run.ops) {
|
for (auto op : run.ops) {
|
||||||
llvm::append_range(channelIds, op.getChannelIds());
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
@@ -629,13 +765,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
: RankedTensorType {};
|
: RankedTensorType {};
|
||||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.ops.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto compactReceive =
|
auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create(
|
||||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds);
|
||||||
run.ops.front().getLoc(),
|
|
||||||
packedType,
|
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
|
||||||
if (concatOp && concatPackedType) {
|
if (concatOp && concatPackedType) {
|
||||||
replaceConcatRunWithPackedValue(
|
replaceConcatRunWithPackedValue(
|
||||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||||
@@ -663,9 +794,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (run.ops.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<Value> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<Value> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<Value> targetCoreIds;
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
inputs.reserve(run.ops.size());
|
inputs.reserve(run.ops.size());
|
||||||
for (auto op : run.ops) {
|
for (auto op : run.ops) {
|
||||||
@@ -678,12 +809,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
rewriter.setInsertionPoint(run.ops.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
spatial::SpatChannelSendTensorBatchOp::create(
|
||||||
run.ops.front().getLoc(),
|
rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput);
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
|
||||||
packedInput);
|
|
||||||
for (auto op : run.ops)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
@@ -700,6 +827,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
void compactRegularOpRuns(func::FuncOp funcOp) {
|
void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
|
|
||||||
auto compactInBlock = [&](Block& block) {
|
auto compactInBlock = [&](Block& block) {
|
||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
@@ -740,7 +868,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
for (const RegularChunk& chunk : run)
|
for (const RegularChunk& chunk : run)
|
||||||
originalOpCount += chunk.ops.size();
|
originalOpCount += chunk.ops.size();
|
||||||
|
|
||||||
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
|
RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder);
|
||||||
if (result.changed) {
|
if (result.changed) {
|
||||||
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||||
if (!result.resumeAfter) {
|
if (!result.resumeAfter) {
|
||||||
@@ -763,6 +891,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
OperationFolder constantFolder(funcOp.getContext());
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
Block& block = compute.getBody().front();
|
Block& block = compute.getBody().front();
|
||||||
@@ -784,7 +913,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||||
if (current.getWeightIndex() != wvmmOp.getWeightIndex()
|
if (current.getWeight() != wvmmOp.getWeight()
|
||||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||||
@@ -851,9 +980,9 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.ops.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
|
auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder);
|
||||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
|
auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder);
|
||||||
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
|
auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder);
|
||||||
auto packedInit =
|
auto packedInit =
|
||||||
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||||
auto loop =
|
auto loop =
|
||||||
@@ -868,7 +997,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
Value sourceRow = iv;
|
Value sourceRow = iv;
|
||||||
if (firstRow != 0) {
|
if (firstRow != 0) {
|
||||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
|
auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder);
|
||||||
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -883,7 +1012,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
extractSizes,
|
extractSizes,
|
||||||
extractStrides);
|
extractStrides);
|
||||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||||
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult());
|
||||||
|
|
||||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <iterator>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -64,6 +64,49 @@ bool isUsedAsWeightOnly(Operation *producerOp) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
|
||||||
|
auto offsetValue = llvm::dyn_cast<Value>(offset);
|
||||||
|
return offsetValue == laneArg;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||||
|
auto inputIt = llvm::find(batch.getInputs(), input);
|
||||||
|
if (inputIt == batch.getInputs().end())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
size_t inputIndex = std::distance(batch.getInputs().begin(), inputIt);
|
||||||
|
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
|
||||||
|
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
|
||||||
|
if (!inputArg || !laneArg)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
Weight projectedCost = 0;
|
||||||
|
for (Operation* user : inputArg->getUsers()) {
|
||||||
|
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||||
|
if (!extract || extract.getSource() != *inputArg)
|
||||||
|
return std::nullopt;
|
||||||
|
if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return std::nullopt;
|
||||||
|
projectedCost = checkedAdd(projectedCost, static_cast<Weight>(getSizeInBytes(resultType)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (projectedCost == 0)
|
||||||
|
return std::nullopt;
|
||||||
|
return projectedCost;
|
||||||
|
}
|
||||||
|
|
||||||
|
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
|
||||||
|
auto inputType = cast<ShapedType>(input.getType());
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
||||||
|
if (std::optional<Weight> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
||||||
|
return *projectedCost;
|
||||||
|
return static_cast<Weight>(getSizeInBytes(inputType));
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
for (const ComputeGraphEdge& edge : edges) {
|
for (const ComputeGraphEdge& edge : edges) {
|
||||||
@@ -99,8 +142,7 @@ CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
|
|||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
return getSpatComputeCrossbarUsage(spatCompute);
|
return getSpatComputeCrossbarUsage(spatCompute);
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
|
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
||||||
static_cast<CrossbarUsage>(instance.laneCount));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||||
@@ -114,7 +156,8 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
|
|||||||
continue;
|
continue;
|
||||||
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||||
size_t index = graph.nodes.size();
|
size_t index = graph.nodes.size();
|
||||||
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
graph.nodes.push_back(
|
||||||
|
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
graph.instanceToIndex[instance] = index;
|
graph.instanceToIndex[instance] = index;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -136,15 +179,27 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
|
|||||||
|
|
||||||
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||||
for (Value input : getComputeInstanceInputs(node.instance)) {
|
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
||||||
auto producerInstance = getComputeProducerInstance(input);
|
for (Value input : inputs) {
|
||||||
|
Weight transferCost = getInputTransferCost(node.instance, input);
|
||||||
|
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||||
|
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||||
|
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||||
|
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
||||||
|
if (producerIt == graph.instanceToIndex.end())
|
||||||
|
continue;
|
||||||
|
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto producerInstance = getComputeProducerInstance(input, &node.instance);
|
||||||
if (!producerInstance)
|
if (!producerInstance)
|
||||||
continue;
|
continue;
|
||||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||||
if (producerIt == graph.instanceToIndex.end())
|
if (producerIt == graph.instanceToIndex.end())
|
||||||
continue;
|
continue;
|
||||||
rawEdges.push_back(
|
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,8 +37,7 @@ struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
|||||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
|
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
|
||||||
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||||
}
|
}
|
||||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) {
|
||||||
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
|
||||||
return lhs == rhs;
|
return lhs == rhs;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
+94
-31
@@ -1,6 +1,8 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "ComputeInstanceUtils.hpp"
|
#include "ComputeInstanceUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
@@ -18,40 +20,81 @@ size_t getSchedulingCpuBudget() {
|
|||||||
|
|
||||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||||
assert(laneCount > 0 && "laneCount must be positive");
|
assert(laneCount > 0 && "laneCount must be positive");
|
||||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
return static_cast<size_t>(laneCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||||
size_t totalLanes = batch.getLaneCount();
|
assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
|
|
||||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
|
||||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
|
||||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||||
size_t totalLanes = batch.getLaneCount();
|
assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
return {batch.getOperation(), lane, 1};
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
|
||||||
|
|
||||||
size_t chunkIndex = 0;
|
|
||||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
|
||||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
|
||||||
else
|
|
||||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
|
||||||
return getBatchChunkForIndex(batch, chunkIndex);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
|
||||||
|
if (extract.getMixedOffsets().empty())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
OpFoldResult offset = extract.getMixedOffsets().front();
|
||||||
|
if (Attribute attr = llvm::dyn_cast<Attribute>(offset)) {
|
||||||
|
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
|
if (!intAttr || intAttr.getInt() < 0)
|
||||||
|
return std::nullopt;
|
||||||
|
return static_cast<uint32_t>(intAttr.getInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
Value offsetValue = llvm::cast<Value>(offset);
|
||||||
|
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
|
||||||
|
if (constantIndex.value() < 0)
|
||||||
|
return std::nullopt;
|
||||||
|
return static_cast<uint32_t>(constantIndex.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::optional<ProducerValueRef> getResultfulBatchProducerValueRef(SpatComputeBatch batch,
|
||||||
|
const ComputeInstance* consumerInstance) {
|
||||||
|
if (!consumerInstance)
|
||||||
|
return std::nullopt;
|
||||||
|
if (!isa<SpatComputeBatch>(consumerInstance->op))
|
||||||
|
return std::nullopt;
|
||||||
|
if (consumerInstance->laneStart + consumerInstance->laneCount > static_cast<uint32_t>(batch.getLaneCount()))
|
||||||
|
return std::nullopt;
|
||||||
|
return ProducerValueRef {
|
||||||
|
{batch.getOperation(), consumerInstance->laneStart, consumerInstance->laneCount},
|
||||||
|
0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ProducerValueRef> getProducerValueRef(Value value, const ComputeInstance* consumerInstance) {
|
||||||
Operation* op = value.getDefiningOp();
|
Operation* op = value.getDefiningOp();
|
||||||
if (!op)
|
if (!op)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
Value source = extract.getSource();
|
||||||
|
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
|
||||||
|
if (batch && batch.getNumResults() != 0) {
|
||||||
|
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
|
||||||
|
if (*lane >= static_cast<uint32_t>(batch.getLaneCount()))
|
||||||
|
return std::nullopt;
|
||||||
|
return ProducerValueRef {
|
||||||
|
{batch.getOperation(), *lane, 1},
|
||||||
|
0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return getResultfulBatchProducerValueRef(batch, consumerInstance);
|
||||||
|
}
|
||||||
|
|
||||||
|
value = source;
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||||
return ProducerValueRef {
|
return ProducerValueRef {
|
||||||
ComputeInstance {compute.getOperation(), 0, 1},
|
ComputeInstance {compute.getOperation(), 0, 1},
|
||||||
@@ -60,6 +103,8 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||||
|
if (batch.getNumResults() != 0)
|
||||||
|
return getResultfulBatchProducerValueRef(batch, consumerInstance);
|
||||||
uint32_t lane = cast<OpResult>(value).getResultNumber();
|
uint32_t lane = cast<OpResult>(value).getResultNumber();
|
||||||
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||||
size_t resultIndex = lane - instance.laneStart;
|
size_t resultIndex = lane - instance.laneStart;
|
||||||
@@ -69,8 +114,8 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
std::optional<ComputeInstance> getComputeProducerInstance(Value value, const ComputeInstance* consumerInstance) {
|
||||||
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value))
|
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value, consumerInstance))
|
||||||
return producer->instance;
|
return producer->instance;
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
@@ -80,11 +125,18 @@ llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &inst
|
|||||||
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
||||||
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
if (batch.getNumResults() != 0)
|
||||||
|
return llvm::SmallVector<Value, 4>(batch.getInputs().begin(), batch.getInputs().end());
|
||||||
|
|
||||||
|
assert(batch.getInputs().size() % static_cast<size_t>(batch.getLaneCount()) == 0
|
||||||
|
&& "resultless compute_batch inputs must be evenly partitioned by lane");
|
||||||
|
size_t inputsPerLane = batch.getInputs().size() / static_cast<size_t>(batch.getLaneCount());
|
||||||
llvm::SmallVector<Value, 4> inputs;
|
llvm::SmallVector<Value, 4> inputs;
|
||||||
inputs.reserve(instance.laneCount);
|
inputs.reserve(instance.laneCount * inputsPerLane);
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
|
||||||
if (!batch.getInputs().empty())
|
size_t firstInput = static_cast<size_t>(lane) * inputsPerLane;
|
||||||
inputs.push_back(batch.getInputs()[lane]);
|
inputs.append(batch.getInputs().begin() + firstInput, batch.getInputs().begin() + firstInput + inputsPerLane);
|
||||||
|
}
|
||||||
return inputs;
|
return inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,10 +145,18 @@ llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &ins
|
|||||||
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
if (batch.getNumResults() != 0)
|
||||||
|
return llvm::SmallVector<Value, 4>(batch.getWeights().begin(), batch.getWeights().end());
|
||||||
|
|
||||||
|
assert(batch.getWeights().size() % static_cast<size_t>(batch.getLaneCount()) == 0
|
||||||
|
&& "resultless compute_batch weights must be evenly partitioned by lane");
|
||||||
|
size_t weightsPerLane = batch.getWeights().size() / static_cast<size_t>(batch.getLaneCount());
|
||||||
llvm::SmallVector<Value, 4> weights;
|
llvm::SmallVector<Value, 4> weights;
|
||||||
weights.reserve(instance.laneCount);
|
weights.reserve(instance.laneCount * weightsPerLane);
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
|
||||||
weights.push_back(batch.getWeights()[lane]);
|
size_t firstWeight = static_cast<size_t>(lane) * weightsPerLane;
|
||||||
|
weights.append(batch.getWeights().begin() + firstWeight, batch.getWeights().begin() + firstWeight + weightsPerLane);
|
||||||
|
}
|
||||||
return weights;
|
return weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,6 +165,9 @@ llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance
|
|||||||
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
||||||
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
if (batch.getNumResults() != 0)
|
||||||
|
return llvm::SmallVector<Value, 4>(batch.getResults().begin(), batch.getResults().end());
|
||||||
|
|
||||||
llvm::SmallVector<Value, 4> outputs;
|
llvm::SmallVector<Value, 4> outputs;
|
||||||
outputs.reserve(instance.laneCount);
|
outputs.reserve(instance.laneCount);
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
|||||||
+4
-2
@@ -26,8 +26,10 @@ size_t getBatchChunkTargetCount(int32_t laneCount);
|
|||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||||
|
|
||||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
|
||||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
const ComputeInstance* consumerInstance = nullptr);
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
|
||||||
|
const ComputeInstance* consumerInstance = nullptr);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance& instance);
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance& instance);
|
||||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance& instance);
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance& instance);
|
||||||
|
|||||||
@@ -10,8 +10,8 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "DcpScheduler.hpp"
|
|
||||||
#include "../DCPGraph/Graph.hpp"
|
#include "../DCPGraph/Graph.hpp"
|
||||||
|
#include "DcpScheduler.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
@@ -404,7 +404,8 @@ bool coarsenGraph(const VirtualGraph &graph,
|
|||||||
VirtualNode mergedNode;
|
VirtualNode mergedNode;
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(),
|
||||||
|
memberNode.originalNodeIndices.end());
|
||||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||||
}
|
}
|
||||||
@@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
MergeScheduleResult
|
||||||
|
runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
|
||||||
llvm::SmallVector<Weight> nodeWeights;
|
llvm::SmallVector<Weight> nodeWeights;
|
||||||
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||||
llvm::SmallVector<int64_t> nodeOrderKeys;
|
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
|
|||||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||||
|
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
+10
-15
@@ -1,13 +1,13 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ComputeGraph.hpp"
|
|
||||||
#include "../DCPGraph/DCPAnalysis.hpp"
|
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
#include "DcpScheduler.hpp"
|
#include "DcpScheduler.hpp"
|
||||||
#include "MergeSchedulingAnalysis.hpp"
|
#include "MergeSchedulingAnalysis.hpp"
|
||||||
#include "PeftScheduler.hpp"
|
#include "PeftScheduler.hpp"
|
||||||
@@ -20,10 +20,8 @@ namespace {
|
|||||||
|
|
||||||
MergeSchedulerKind getSchedulerKind() {
|
MergeSchedulerKind getSchedulerKind() {
|
||||||
switch (pimMergeScheduler.getValue()) {
|
switch (pimMergeScheduler.getValue()) {
|
||||||
case MergeSchedulerPeft:
|
case MergeSchedulerPeft: return MergeSchedulerKind::Peft;
|
||||||
return MergeSchedulerKind::Peft;
|
case MergeSchedulerDcp: return MergeSchedulerKind::Dcp;
|
||||||
case MergeSchedulerDcp:
|
|
||||||
return MergeSchedulerKind::Dcp;
|
|
||||||
}
|
}
|
||||||
llvm_unreachable("unknown merge scheduler kind");
|
llvm_unreachable("unknown merge scheduler kind");
|
||||||
}
|
}
|
||||||
@@ -115,19 +113,16 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
|||||||
|
|
||||||
MergeScheduleResult schedule;
|
MergeScheduleResult schedule;
|
||||||
if (options.kind == MergeSchedulerKind::Peft) {
|
if (options.kind == MergeSchedulerKind::Peft) {
|
||||||
schedule = runPeftScheduler(
|
schedule = runPeftScheduler(graph,
|
||||||
graph,
|
PeftScheduleOptions {options.processorCount,
|
||||||
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||||
entryOp->getContext()});
|
entryOp->getContext()});
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
schedule = runDcpScheduler(
|
schedule = runDcpScheduler(graph,
|
||||||
graph,
|
DcpScheduleOptions {options.processorCount,
|
||||||
DcpScheduleOptions {
|
|
||||||
options.processorCount,
|
|
||||||
dcpCriticalWindowSize.getValue(),
|
dcpCriticalWindowSize.getValue(),
|
||||||
options.allowDcpFallbackForAutoCoreCount
|
options.allowDcpFallbackForAutoCoreCount},
|
||||||
},
|
|
||||||
entryOp->getContext());
|
entryOp->getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+117
-20
@@ -19,7 +19,6 @@ struct ScheduledTask {
|
|||||||
size_t processor = std::numeric_limits<size_t>::max();
|
size_t processor = std::numeric_limits<size_t>::max();
|
||||||
Time startTime = 0;
|
Time startTime = 0;
|
||||||
Time endTime = 0;
|
Time endTime = 0;
|
||||||
size_t slot = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||||
@@ -44,7 +43,6 @@ std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph &graph) {
|
|||||||
levelNodes.push_back(node);
|
levelNodes.push_back(node);
|
||||||
++levelizedCount;
|
++levelizedCount;
|
||||||
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
||||||
(void) weight;
|
|
||||||
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||||
if (--remainingSuccessors[pred] == 0)
|
if (--remainingSuccessors[pred] == 0)
|
||||||
readySinks.push(pred);
|
readySinks.push(pred);
|
||||||
@@ -88,9 +86,14 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
verifyOctTableSize(nodeCount, processorCount);
|
verifyOctTableSize(nodeCount, processorCount);
|
||||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
|
|
||||||
|
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||||
|
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
||||||
|
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
||||||
|
|
||||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
|
|
||||||
|
// 1. O(P(E+V)) Heterogeneous OCT Calculation
|
||||||
for (const std::vector<size_t>& levelNodes : reverseLevels) {
|
for (const std::vector<size_t>& levelNodes : reverseLevels) {
|
||||||
auto computeNodeOct = [&](size_t levelIndex) {
|
auto computeNodeOct = [&](size_t levelIndex) {
|
||||||
size_t task = levelNodes[levelIndex];
|
size_t task = levelNodes[levelIndex];
|
||||||
@@ -99,7 +102,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
for (const auto& [succ, comm] : graph.successors[task]) {
|
for (const auto& [succ, comm] : graph.successors[task]) {
|
||||||
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], graph.nodes[succ].weight);
|
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
|
||||||
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
||||||
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
||||||
}
|
}
|
||||||
@@ -108,7 +111,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
Time minForPreds = std::numeric_limits<Time>::max();
|
Time minForPreds = std::numeric_limits<Time>::max();
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
oct[task * processorCount + processor] = maxVals[processor];
|
oct[task * processorCount + processor] = maxVals[processor];
|
||||||
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], graph.nodes[task].weight));
|
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
|
||||||
}
|
}
|
||||||
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
||||||
};
|
};
|
||||||
@@ -132,6 +135,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
||||||
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options.context != nullptr)
|
if (options.context != nullptr)
|
||||||
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
||||||
else
|
else
|
||||||
@@ -157,7 +161,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> scheduled(nodeCount, false);
|
std::vector<char> scheduled(nodeCount, false);
|
||||||
std::vector<Time> processorAvailable(processorCount, 0);
|
|
||||||
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
||||||
std::vector<ScheduledTask> schedules(nodeCount);
|
std::vector<ScheduledTask> schedules(nodeCount);
|
||||||
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||||
@@ -176,8 +179,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
bool crossbarRejected = false;
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
if (graph.nodes[task].crossbarUsage != 0 &&
|
if (graph.nodes[task].crossbarUsage != 0
|
||||||
addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
&& addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
||||||
crossbarRejected = true;
|
crossbarRejected = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -189,13 +192,33 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||||
}
|
}
|
||||||
|
|
||||||
Time est = std::max(processorAvailable[processor], dataReady);
|
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||||
Time eft = addOrMax(est, graph.nodes[task].weight);
|
Time compWeight = getComputeCost(task, processor);
|
||||||
|
Time est = dataReady;
|
||||||
|
Time currentEnd = 0;
|
||||||
|
bool foundGap = false;
|
||||||
|
|
||||||
|
for (size_t schedTaskIndex : tasksByProcessor[processor]) {
|
||||||
|
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||||
|
Time gapStart = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
|
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
||||||
|
est = gapStart;
|
||||||
|
foundGap = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
currentEnd = schedTask.endTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!foundGap)
|
||||||
|
est = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
|
Time eft = addOrMax(est, compWeight);
|
||||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
|
||||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft) ||
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
(oeft == bestOeft && eft == bestEft && est < bestEst) ||
|
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
||||||
(oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||||
bestProcessor = processor;
|
bestProcessor = processor;
|
||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
@@ -219,12 +242,15 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
llvm::report_fatal_error(llvm::StringRef(message));
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
}
|
}
|
||||||
|
|
||||||
schedules[task] = {bestProcessor, bestEst, bestEft, tasksByProcessor[bestProcessor].size()};
|
schedules[task] = {bestProcessor, bestEst, bestEft};
|
||||||
scheduled[task] = true;
|
scheduled[task] = true;
|
||||||
++scheduledCount;
|
++scheduledCount;
|
||||||
processorAvailable[bestProcessor] = bestEft;
|
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||||
processorCrossbars[bestProcessor] =
|
|
||||||
addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
// 3. CRITICAL FIX: Topological Append
|
||||||
|
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||||
|
// back guarantees the Monoliths will be physically generated cycle-free.
|
||||||
|
// The hardware will still benefit from the processor assignment chosen by PEFT.
|
||||||
tasksByProcessor[bestProcessor].push_back(task);
|
tasksByProcessor[bestProcessor].push_back(task);
|
||||||
|
|
||||||
for (const auto& [child, weight] : graph.successors[task]) {
|
for (const auto& [child, weight] : graph.successors[task]) {
|
||||||
@@ -238,16 +264,86 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
if (scheduledCount != nodeCount)
|
if (scheduledCount != nodeCount)
|
||||||
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
||||||
|
|
||||||
|
// 4. Build Strict Topological Dominance Order
|
||||||
|
std::vector<size_t> scheduledOrder(nodeCount);
|
||||||
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
|
scheduledOrder[i] = i;
|
||||||
|
|
||||||
|
std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
|
||||||
|
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||||
|
});
|
||||||
|
|
||||||
|
// 5. Check if equal schedule in two level
|
||||||
|
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||||
|
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
|
||||||
|
for (size_t controlProcessor = currentProcessor; controlProcessor < processorCount; ++controlProcessor) {
|
||||||
|
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
|
||||||
|
continue;
|
||||||
|
auto& currentTasks = tasksByProcessor[currentProcessor];
|
||||||
|
auto& controlTasks = tasksByProcessor[controlProcessor];
|
||||||
|
bool equalSchedule = true;
|
||||||
|
|
||||||
|
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
|
||||||
|
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
|
||||||
|
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
|
||||||
|
if (currentComputeInstance.op != controlComputeInstance.op
|
||||||
|
|| currentComputeInstance.laneCount != controlComputeInstance.laneCount) {
|
||||||
|
equalSchedule = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (equalSchedule) {
|
||||||
|
equivalentClass[currentProcessor].push_back(controlProcessor);
|
||||||
|
equivalentClass[controlProcessor].push_back(currentProcessor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*{
|
||||||
|
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
||||||
|
std::vector<bool> visited(processorCount, false);
|
||||||
|
size_t uniqueClassCount = 0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < processorCount; ++i) {
|
||||||
|
if (visited[i])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// We found a new unique schedule (equivalence class)
|
||||||
|
++uniqueClassCount;
|
||||||
|
visited[i] = true;
|
||||||
|
|
||||||
|
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
||||||
|
|
||||||
|
// Find and mark all identical companions
|
||||||
|
auto it = equivalentClass.find(i);
|
||||||
|
if (it != equivalentClass.end()) {
|
||||||
|
for (size_t eqCpu : it->second) {
|
||||||
|
if (!visited[eqCpu]) {
|
||||||
|
llvm::dbgs() << ", " << eqCpu;
|
||||||
|
visited[eqCpu] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm::dbgs() << " }\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
||||||
|
llvm::dbgs() << "--------------------------------------\n";
|
||||||
|
}*/
|
||||||
|
|
||||||
|
// 6. Populate Final Result
|
||||||
MergeScheduleResult result;
|
MergeScheduleResult result;
|
||||||
result.dominanceOrderCompute.reserve(nodeCount);
|
result.dominanceOrderCompute.reserve(nodeCount);
|
||||||
for (const ComputeGraphNode &node : graph.nodes)
|
|
||||||
result.dominanceOrderCompute.push_back(node.instance);
|
for (size_t task : scheduledOrder)
|
||||||
|
result.dominanceOrderCompute.push_back(graph.nodes[task].instance);
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
size_t currentSlot = 0;
|
||||||
for (size_t task : tasksByProcessor[processor]) {
|
for (size_t task : tasksByProcessor[processor]) {
|
||||||
const ComputeInstance instance = graph.nodes[task].instance;
|
const ComputeInstance instance = graph.nodes[task].instance;
|
||||||
result.computeToCpuMap[instance] = processor;
|
result.computeToCpuMap[instance] = processor;
|
||||||
result.computeToCpuSlotMap[instance] = schedules[task].slot;
|
result.computeToCpuSlotMap[instance] = currentSlot++;
|
||||||
result.computeToAestMap[instance] = schedules[task].startTime;
|
result.computeToAestMap[instance] = schedules[task].startTime;
|
||||||
}
|
}
|
||||||
if (!tasksByProcessor[processor].empty()) {
|
if (!tasksByProcessor[processor].empty()) {
|
||||||
@@ -257,8 +353,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result.equivalentClass = equivalentClass;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
|
|
||||||
rewriter.setInsertionPoint(mapOp);
|
rewriter.setInsertionPoint(mapOp);
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = getShapedTypeSizeInBytes(initType);
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
mapOp.getLoc(),
|
mapOp.getLoc(),
|
||||||
initType,
|
initType,
|
||||||
|
|||||||
@@ -176,9 +176,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|||||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
if (!hasByteSizedElementType(sourceType.getElementType()))
|
||||||
if (elementByteWidth <= 0)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||||
|
|
||||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||||
if (size != totalBytes)
|
if (size != totalBytes)
|
||||||
@@ -268,24 +268,31 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
|
||||||
|
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
|
||||||
|
if (failed(dstOffset) || failed(srcOffset))
|
||||||
|
return failure();
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
auto status = rewriteSubviewCopyLikeOp(
|
||||||
copyOp,
|
copyOp,
|
||||||
copyOp.getDeviceTarget(),
|
copyOp.getDeviceTarget(),
|
||||||
copyOp.getHostSource(),
|
copyOp.getHostSource(),
|
||||||
copyOp.getDeviceTargetOffset(),
|
*dstOffset,
|
||||||
copyOp.getHostSourceOffset(),
|
*srcOffset,
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
/*allowLoopRewrite=*/true,
|
/*allowLoopRewrite=*/true,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
|
Value dstOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset);
|
||||||
|
Value srcOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset);
|
||||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
copyOp.getLoc(),
|
copyOp.getLoc(),
|
||||||
resultType,
|
resultType,
|
||||||
|
dstOffsetValue,
|
||||||
|
srcOffsetValue,
|
||||||
dst,
|
dst,
|
||||||
src,
|
src,
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||||
});
|
});
|
||||||
if (failed(status))
|
if (failed(status))
|
||||||
@@ -301,24 +308,31 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
|
||||||
|
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
|
||||||
|
if (failed(dstOffset) || failed(srcOffset))
|
||||||
|
return failure();
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
auto status = rewriteSubviewCopyLikeOp(
|
||||||
copyOp,
|
copyOp,
|
||||||
copyOp.getHostTarget(),
|
copyOp.getHostTarget(),
|
||||||
copyOp.getDeviceSource(),
|
copyOp.getDeviceSource(),
|
||||||
copyOp.getHostTargetOffset(),
|
*dstOffset,
|
||||||
copyOp.getDeviceSourceOffset(),
|
*srcOffset,
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
/*allowLoopRewrite=*/false,
|
/*allowLoopRewrite=*/false,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
|
Value dstOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset);
|
||||||
|
Value srcOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset);
|
||||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
copyOp.getLoc(),
|
copyOp.getLoc(),
|
||||||
resultType,
|
resultType,
|
||||||
|
dstOffset,
|
||||||
|
srcOffset,
|
||||||
dst,
|
dst,
|
||||||
src,
|
src,
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||||
});
|
});
|
||||||
if (failed(status))
|
if (failed(status))
|
||||||
@@ -355,9 +369,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
|||||||
if (failed(staticOffsets))
|
if (failed(staticOffsets))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
auto resultMemRefType = cast<MemRefType>(subviewOp.getType());
|
||||||
auto resultMemRefType =
|
|
||||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
|
||||||
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
||||||
if (failed(foldedAttr))
|
if (failed(foldedAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -23,23 +23,19 @@ namespace {
|
|||||||
|
|
||||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
return operandIndex == 1;
|
return operandIndex == 3;
|
||||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||||
return operandIndex == 1;
|
return operandIndex == 1;
|
||||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
return operandIndex == 0;
|
return operandIndex == 2;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int64_t getValueSizeInBytes(Value value) {
|
|
||||||
auto type = dyn_cast<ShapedType>(value.getType());
|
|
||||||
if (!type || !type.hasStaticShape())
|
|
||||||
return -1;
|
|
||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename CoreOpTy>
|
template <typename CoreOpTy>
|
||||||
static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) {
|
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||||
|
IRRewriter& rewriter,
|
||||||
|
OperationFolder& constantFolder,
|
||||||
|
bool& hasFailure) {
|
||||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||||
SmallVector<Operation*> ops;
|
SmallVector<Operation*> ops;
|
||||||
coreOp.getBody().front().walk([&](Operation* op) {
|
coreOp.getBody().front().walk([&](Operation* op) {
|
||||||
@@ -48,6 +44,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (Operation* op : ops) {
|
for (Operation* op : ops) {
|
||||||
|
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
||||||
|
continue;
|
||||||
|
|
||||||
for (OpOperand& operand : op->getOpOperands()) {
|
for (OpOperand& operand : op->getOpOperands()) {
|
||||||
Value originalValue = operand.get();
|
Value originalValue = operand.get();
|
||||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
|
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||||
@@ -76,7 +75,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
int64_t totalBytes = -1;
|
||||||
|
if (auto type = dyn_cast<ShapedType>(originalValue.getType()); type && type.hasStaticShape())
|
||||||
|
totalBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(type));
|
||||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||||
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
@@ -105,14 +106,15 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
copiedValue = pim::PimMemCopyHostToDevOp::create(
|
copiedValue =
|
||||||
|
pim::PimMemCopyHostToDevOp::create(
|
||||||
rewriter,
|
rewriter,
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
originalType,
|
originalType,
|
||||||
|
getOrCreateHostIndexConstant(op, 0, constantFolder),
|
||||||
|
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder),
|
||||||
deviceDst,
|
deviceDst,
|
||||||
getGlobalOp.getResult(),
|
getGlobalOp.getResult(),
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
@@ -134,6 +136,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
|||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
IRRewriter rewriter(moduleOp.getContext());
|
IRRewriter rewriter(moduleOp.getContext());
|
||||||
|
OperationFolder constantFolder(moduleOp.getContext());
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
|
|
||||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||||
@@ -141,10 +144,10 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||||
materializeHostConstantsInCore(coreOp, rewriter, hasFailure);
|
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
||||||
|
|
||||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||||
materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure);
|
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
||||||
|
|
||||||
SmallVector<Operation*> hostCompactOps;
|
SmallVector<Operation*> hostCompactOps;
|
||||||
for (Operation& op : funcOp.getBody().front())
|
for (Operation& op : funcOp.getBody().front())
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -119,11 +120,27 @@ static bool isConstantGlobalView(Value value) {
|
|||||||
|
|
||||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
return operandIndex == 1;
|
return operandIndex == 3;
|
||||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||||
return operandIndex == 1;
|
return operandIndex == 1;
|
||||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
return operandIndex == 0;
|
return operandIndex == 2;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isCoreWeightBlockArgument(Value value) {
|
||||||
|
auto blockArgument = dyn_cast<BlockArgument>(value);
|
||||||
|
if (!blockArgument)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(blockArgument.getOwner()->getParentOp()))
|
||||||
|
return static_cast<unsigned>(blockArgument.getArgNumber()) < coreOp.getWeights().size();
|
||||||
|
|
||||||
|
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(blockArgument.getOwner()->getParentOp())) {
|
||||||
|
unsigned argNumber = static_cast<unsigned>(blockArgument.getArgNumber());
|
||||||
|
return argNumber > 0 && argNumber <= coreBatchOp.getWeights().size();
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +210,10 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
|
|
||||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||||
(void) verifyCoreOperands(coreBatchOp, diagnostics);
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||||
|
(void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
|
||||||
|
return verifyCoreOperands(scalarCore, diagnostics);
|
||||||
|
});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,6 +317,9 @@ private:
|
|||||||
if (!isa<BaseMemRefType>(operand.getType()))
|
if (!isa<BaseMemRefType>(operand.getType()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
if (isCoreWeightBlockArgument(operand))
|
||||||
|
continue;
|
||||||
|
|
||||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||||
if (failed(resolvedAddress)) {
|
if (failed(resolvedAddress)) {
|
||||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
@@ -327,6 +350,26 @@ private:
|
|||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
||||||
|
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|
||||||
|
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||||
|
});
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
|
||||||
|
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|
||||||
|
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||||
|
});
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ python3 validation/operations/gen_tests.py
|
|||||||
|
|
||||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||||
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
|
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
|
||||||
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
|
| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
|
||||||
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
||||||
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
|
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
|
||||||
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
|
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
|
||||||
|
|||||||
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
|
|||||||
# GEMM tests
|
# GEMM tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def gemm_simple():
|
||||||
|
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
|
||||||
|
B, K, N = 10, 132, 132
|
||||||
|
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/simple", "gemm_simple.onnx")
|
||||||
|
|
||||||
|
|
||||||
def gemm_non_square():
|
def gemm_non_square():
|
||||||
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
||||||
B, K, N = 4, 128, 64
|
B, K, N = 4, 128, 64
|
||||||
@@ -823,6 +835,7 @@ def div_after_gemm():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Generating GEMM tests:")
|
print("Generating GEMM tests:")
|
||||||
|
gemm_simple()
|
||||||
gemm_non_square()
|
gemm_non_square()
|
||||||
gemm_with_bias()
|
gemm_with_bias()
|
||||||
gemm_transB()
|
gemm_transB()
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def main():
|
|||||||
help="Core count to pass to Raptor. Required for PIM validation.")
|
help="Core count to pass to Raptor. Required for PIM validation.")
|
||||||
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
|
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
|
||||||
help="Scheduler used by the Spatial merge-compute-nodes pass.")
|
help="Scheduler used by the Spatial merge-compute-nodes pass.")
|
||||||
ap.add_argument("--command-timeout-seconds", type=float, default=60.0,
|
ap.add_argument("--command-timeout-seconds", type=float, default=1000000.0,
|
||||||
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
|
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
|
||||||
ap.add_argument("--clean", action="store_true",
|
ap.add_argument("--clean", action="store_true",
|
||||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||||
|
|||||||
Reference in New Issue
Block a user