30 Commits

Author SHA1 Message Date
NiccoloN d609e84054 teh only weight (WIP)
Validate Operations / validate-operations (push) Waiting to run
2026-05-26 18:42:14 +02:00
NiccoloN addfc8a86e remove other dead logic
Validate Operations / validate-operations (push) Waiting to run
2026-05-25 21:22:08 +02:00
NiccoloN 0f240af271 cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Waiting to run
2026-05-25 20:58:51 +02:00
ilgeco bdc4ca33f3 No extract no more
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 18:19:43 +02:00
ilgeco b79c333c6c Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-25 15:44:40 +02:00
ilgeco eea9261c7b Bye Bye DCP 2026-05-25 15:44:30 +02:00
NiccoloN e8a08f6dd0 faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 15:24:12 +02:00
NiccoloN 4855a2e105 add verification of static weights in spatial
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 12:00:42 +02:00
NiccoloN 3a7a832198 MaterializeMergeSchedule.cpp fix for yolo11_depth_18 2026-05-24 11:54:00 +02:00
NiccoloN 48ca6bd28d speed fix with a simple cache
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 10:52:28 +02:00
NiccoloN f595cc6ffd fix high memory usage in IR 2026-05-24 10:41:47 +02:00
NiccoloN c734f1b37e better MaterializeMergeSchedule.cpp that emits much more compact IR
Validate Operations / validate-operations (push) Has been cancelled
add support for other constant-time arith ops in codegen
2026-05-24 10:10:24 +02:00
NiccoloN b79ce8eeaa use affine dialect to express simple constant progressions
Validate Operations / validate-operations (push) Has been cancelled
run dce at the end of MaterializeMergeSchedule to get rid of unused constants
2026-05-23 14:25:34 +02:00
NiccoloN 76a37e198f better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops
Validate Operations / validate-operations (push) Has been cancelled
2026-05-23 11:17:36 +02:00
NiccoloN 7f3c7464b4 update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 22:16:19 +02:00
NiccoloN c77ffa9c56 better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
2026-05-22 21:52:28 +02:00
NiccoloN 495186503c fix cmake magic once again 2026-05-22 19:21:56 +02:00
NiccoloN 2c1da813b5 fix much stuff 2026-05-22 18:53:38 +02:00
NiccoloN 8337a11ce9 automatic code reformat 2026-05-22 15:23:48 +02:00
ilgeco d136136d22 Fix add of input in random order for compute_batch
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 15:21:02 +02:00
NiccoloN 074eb183c7 saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 07:27:54 +02:00
NiccoloN 43ed3914b8 better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 06:56:39 +02:00
ilgeco 6aaf1c0870 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:44:19 +02:00
ilgeco fe35b3ed43 Equivalent Class but broken 2026-05-21 14:43:59 +02:00
NiccoloN 90a9339686 better cmake to keep IDEs analyses happy
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:13:54 +02:00
NiccoloN a50e77ff38 refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-20 19:06:41 +02:00
NiccoloN f56c4159b5 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-19 15:01:26 +02:00
ilgeco 5637c861b4 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-19 15:00:11 +02:00
ilgeco 94157a8404 Very big timeout 2026-05-19 14:53:34 +02:00
ilgeco 68a3521978 Perft topological fix 2026-05-19 14:52:54 +02:00
128 changed files with 9038 additions and 9472 deletions
+92 -24
View File
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir
function(raptor_ensure_symlink link_path target_path)
get_filename_component(link_parent "${link_path}" DIRECTORY)
# Materialize a CMake shim directory
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
if(NOT EXISTS "${link_parent}")
message(FATAL_ERROR "Directory not found: ${link_parent}")
endif()
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
SYMBOLIC
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
message(FATAL_ERROR
"External CMake source directory not found or missing CMakeLists.txt:\n"
" ${real_external_source_dir}"
)
endif()
endif ()
if (IS_SYMLINK "${shim_dir}")
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
file(REMOVE "${shim_dir}")
endif ()
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
endif ()
file(MAKE_DIRECTORY "${shim_dir}")
set(shim_file "${shim_dir}/CMakeLists.txt")
set(shim_contents
"get_filename_component(raptor_external_source_dir
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
REALPATH
)
add_subdirectory(
\"\${raptor_external_source_dir}\"
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
)
if (DEFINED PIM_ENABLED)
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
endif ()
"
)
if (EXISTS "${shim_file}")
file(READ "${shim_file}" old_contents)
else ()
set(old_contents "")
endif ()
if (NOT old_contents STREQUAL shim_contents)
file(WRITE "${shim_file}" "${shim_contents}")
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
else ()
message(STATUS "CMake shim already up to date for ${description}")
endif ()
# Mirror the external tree's first-level entries into the shim directory
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
foreach (child IN LISTS children)
if (child STREQUAL "CMakeLists.txt")
continue()
endif ()
set(real_child "${real_external_source_dir}/${child}")
set(shim_child "${shim_dir}/${child}")
if (IS_SYMLINK "${shim_child}")
file(READ_SYMLINK "${shim_child}" existing_link_target)
if (existing_link_target STREQUAL real_child)
continue()
endif ()
file(REMOVE_RECURSE "${shim_child}")
elseif (EXISTS "${shim_child}")
# Do not delete real files/directories. This protects the generated shim.
continue()
endif ()
file(CREATE_LINK
"${real_child}"
"${shim_child}"
SYMBOLIC
)
endforeach ()
endfunction()
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
"PIM accelerator"
)
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
"PIM accelerator tests"
)
# Patch onnx-mlir sources for PIM accelerator support.
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
# Already applied replacement text is present
string(FIND "${contents}" "${replacement}" already_applied_pos)
if(NOT already_applied_pos EQUAL -1)
if (NOT already_applied_pos EQUAL -1)
message(STATUS "Patch already applied: ${description}")
return()
endif()
endif ()
# Anchor must exist for the patch to be applicable
string(FIND "${contents}" "${anchor}" anchor_pos)
if(anchor_pos EQUAL -1)
if (anchor_pos EQUAL -1)
message(FATAL_ERROR
"Patch anchor not found onnx-mlir may have changed.\n"
" Patch : ${description}\n"
" File : ${file_path}\n"
" Anchor: ${anchor}"
)
endif()
endif ()
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
file(WRITE "${file_path}" "${patched}")
+41 -1
View File
@@ -98,7 +98,7 @@ Supporting pieces:
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
- `src/PIM/Pass` — auxiliary passes (`MessagePass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
@@ -145,6 +145,46 @@ validate.py \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Each validation run writes debugging artifacts into the benchmark's workspace
directory (for example `validation/operations/gemm/small/`):
- `inputs/` — generated input CSVs used for the run.
- `outputs/` — reference outputs dumped by the native ONNX runner.
- `raptor/` — compiler artifacts:
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
`raptor/reports/` such as `dcp_merge_report.txt`,
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
- `runner/` — generated reference runner source, build tree, and shared library.
- `simulation/out.bin` — raw simulator output dump used for output comparison.
That means you usually do not need to rerun standalone `--EmitSpatial` or
`--EmitPim` commands while debugging validation failures: the per-pass dialect
dumps are already available under `raptor/dialects/`.
The validator does not currently expose a simulator tracing flag, but once a
validation has produced `raptor/pim/` you can rerun the simulator manually with
tracing enabled:
```bash
cd backend-simulators/pim/pim-simulator
cargo run --no-default-features --features tracing --release \
--package pim-simulator --bin pim-simulator -- \
-f /path/to/workspace/raptor/pim \
-o /path/to/workspace/simulation/out.bin \
-d <addr0>,<size0>,<addr1>,<size1>,...
```
With `--features tracing`, the simulator writes per-core traces as
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
and the model output shapes. If you need a clean slate before rerunning, use:
```bash
validate.py --clean
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..];
let format_core = |core: &i32| (core - 1).to_string();
let cycle_str = cycle
.iter()
.map(|c| c.to_string())
.map(format_core)
.collect::<Vec<_>>()
.join(" -> ");
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target)
format!("core {} send {}B -> {}", core - 1, size, target - 1)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source)
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
}
CoreState::Working => format!("core {} working", core),
CoreState::Halted => format!("core {} halted", core),
CoreState::Working => format!("core {} working", core - 1),
CoreState::Halted => format!("core {} halted", core - 1),
})
})
.collect::<Vec<_>>()
+53
View File
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
set(PIM_GENERATED_PATH_SHIM_TARGET "")
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
function(add_pim_generated_path_shim relative_path)
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
add_custom_command(
OUTPUT "${shim_file}"
DEPENDS "${real_file}"
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
VERBATIM
)
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
endfunction()
file(GLOB_RECURSE pim_generated_path_scan_sources
CONFIGURE_DEPENDS
"${PIM_SRC_ROOT}/*.cpp"
"${PIM_SRC_ROOT}/*.hpp"
)
set(pim_generated_path_shims)
foreach (source_file IN LISTS pim_generated_path_scan_sources)
file(READ "${source_file}" source_contents)
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
foreach (inc_match IN LISTS source_inc_matches)
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
list(APPEND pim_generated_path_shims "${relative_inc_path}")
endforeach ()
endforeach ()
list(REMOVE_DUPLICATES pim_generated_path_shims)
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
add_pim_generated_path_shim("${relative_inc_path}")
endforeach ()
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
endif ()
set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN})
if (PIM_GENERATED_PATH_SHIM_TARGET)
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
endif ()
endfunction()
add_subdirectory(Dialect)
+2
View File
@@ -1,5 +1,7 @@
add_pim_library(OMPimCommon
IR/AddressAnalysis.cpp
IR/BatchCoreUtils.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
+542 -1
View File
@@ -1,7 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include <limits>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -28,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
return value;
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -55,6 +67,293 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) {
case mlir::arith::CmpIPredicate::eq:
return lhs == rhs;
case mlir::arith::CmpIPredicate::ne:
return lhs != rhs;
case mlir::arith::CmpIPredicate::slt:
return lhs < rhs;
case mlir::arith::CmpIPredicate::sle:
return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt:
return lhs > rhs;
case mlir::arith::CmpIPredicate::sge:
return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult:
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule:
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt:
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge:
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
}
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant:
return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
return iter->second;
return mlir::failure();
}
case CompiledIndexExprNode::Kind::Add:
case CompiledIndexExprNode::Kind::Sub:
case CompiledIndexExprNode::Kind::Mul:
case CompiledIndexExprNode::Kind::DivUI:
case CompiledIndexExprNode::Kind::DivSI:
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add:
return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub:
return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul:
return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::DivSI:
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
return mlir::failure();
return *lhs / *rhs;
case CompiledIndexExprNode::Kind::RemUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::RemSI:
if (*rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI:
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default:
llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
if (failed(condition))
return mlir::failure();
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
}
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
if (!denseAttr || !globalType)
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(expr.node->operands.size());
for (const CompiledIndexExpr& operand : expr.node->operands) {
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
}
llvm_unreachable("unknown compiled index kind");
}
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
expr.globalOp = globalOp;
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
expr.operands.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto compiledIndex = compileIndexValueImpl(index);
if (failed(compiledIndex))
return mlir::failure();
expr.operands.push_back(*compiledIndex);
}
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = integerAttr.getInt();
return makeCompiledIndexExpr(std::move(expr));
}
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
auto lhs = compileIndexValueImpl(lhsValue);
auto rhs = compileIndexValueImpl(rhsValue);
if (failed(lhs) || failed(rhs))
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {*lhs, *rhs};
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
};
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return compileIndexValueImpl(indexCastOp.getIn());
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
if (failed(expr))
return mlir::failure();
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
exprNode->predicate = cmpOp.getPredicate();
return CompiledIndexExpr(exprNode);
}
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
auto lhs = compileIndexValueImpl(maxOp.getLhs());
auto rhs = compileIndexValueImpl(maxOp.getRhs());
if (failed(lhs) || failed(rhs))
return mlir::failure();
CompiledIndexExprNode cmpExpr;
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
cmpExpr.operands = {*lhs, *rhs};
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
return makeCompiledIndexExpr(std::move(selectExpr));
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = compileIndexValueImpl(selectOp.getCondition());
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
if (failed(condition) || failed(trueValue) || failed(falseValue))
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Select;
expr.operands = {*condition, *trueValue, *falseValue};
return makeCompiledIndexExpr(std::move(expr));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return compileConstantGlobalLoad(loadOp);
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -110,6 +409,16 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return mlir::failure();
return *lhs / *rhs;
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
@@ -126,6 +435,34 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
}
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure();
}
@@ -218,7 +555,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
@@ -243,6 +580,191 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
}
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
int64_t constantByteOffset = 0;
CompiledIndexExpr byteOffsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return CompiledAddressExpr {value, byteOffsetExpr};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = tiedOperand->get();
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
if (allStatic) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
constantByteOffset +=
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
* getElementTypeSizeInBytes(subviewType.getElementType());
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
else {
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
if (failed(compiledOffset))
return mlir::failure();
CompiledIndexExpr scaleExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
scaleExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Mul;
expr.operands = {*compiledOffset, scaleExpr};
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {offsetExpr, operandExpr};
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, offsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
constantByteOffset = 0;
}
value = subviewOp.getSource();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
if (constantByteOffset != 0) {
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
byteOffsetExpr = constantExpr;
else {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, byteOffsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
}
return CompiledAddressExpr {value, byteOffsetExpr};
}
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
@@ -251,6 +773,8 @@ llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueK
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
@@ -264,4 +788,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
return compileContiguousAddressExprImpl(value);
}
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress>
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset};
}
} // namespace onnx_mlir
+52
View File
@@ -1,10 +1,14 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include <memory>
#include <optional>
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
@@ -23,6 +27,51 @@ struct StaticValueKnowledge {
StaticValueKnowledge() {}
};
struct CompiledIndexExprNode;
struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
};
struct CompiledIndexExprNode {
enum class Kind {
Constant,
Symbol,
Add,
Sub,
Mul,
DivUI,
DivSI,
RemUI,
RemSI,
MinUI,
CmpI,
Select,
ConstantGlobalLoad
};
Kind kind = Kind::Constant;
int64_t constant = 0;
mlir::Value symbol;
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t, 4> globalStrides;
llvm::SmallVector<CompiledIndexExpr, 4> operands;
};
struct CompiledAddressExpr {
mlir::Value base;
CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress>
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const;
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
@@ -35,9 +84,12 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
} // namespace onnx_mlir
+22
View File
@@ -0,0 +1,22 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
llvm::SmallVector<int32_t>
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
llvm::SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
} // namespace onnx_mlir
+15
View File
@@ -0,0 +1,15 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t>
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
} // namespace onnx_mlir
+82
View File
@@ -0,0 +1,82 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Block* getHostConstantBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(hostBlock);
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
}
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
}
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter);
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
}
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
}
} // namespace onnx_mlir
+32
View File
@@ -0,0 +1,32 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
namespace onnx_mlir {
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Attribute value,
mlir::Type type,
mlir::OperationFolder& folder);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Attribute value,
mlir::Type type,
mlir::RewriterBase& rewriter);
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter);
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
} // namespace onnx_mlir
+82 -13
View File
@@ -1,25 +1,37 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
if (mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::DivSIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::RemSIOp,
mlir::arith::IndexCastOp,
mlir::arith::CmpIOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op))
return true;
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
return selectOp.getType().isIntOrIndex();
return false;
}
mlir::LogicalResult
@@ -30,6 +42,9 @@ walkPimCoreBlock(mlir::Block& block,
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
@@ -65,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block,
return mlir::success(!hasFailure);
}
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
hasFailure = true;
continue;
}
if (*step <= 0) {
forOp.emitOpError("requires positive scf.for step for PIM verification");
hasFailure = true;
continue;
}
llvm::SmallVector<int64_t, 2> samples;
if (*lowerBound < *upperBound) {
samples.push_back(*lowerBound);
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
if (last != *lowerBound)
samples.push_back(last);
}
for (int64_t inductionValue : samples) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
hasFailure = true;
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
+9
View File
@@ -21,4 +21,13 @@ walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
/// Walks a `pim.core`-like body structurally for verification without
/// enumerating full loop trip counts. Loop bounds must still be statically
/// evaluable so address resolution remains well-defined.
mlir::LogicalResult
walkPimCoreBlockStructurally(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)>
callback);
} // namespace onnx_mlir
+25
View File
@@ -1,4 +1,5 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
return numElements;
}
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
+11
View File
@@ -1,8 +1,13 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
+50 -23
View File
@@ -19,29 +19,35 @@ void markWeightAlways(mlir::Operation* op) {
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeight() == *weightArg;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
template <typename VMMOpTy, typename ParentOpTy>
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg || *weightArg != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
@@ -54,7 +60,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
@@ -90,19 +96,40 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreOp->getOpOperand(weightIndex));
break;
}
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreBatchOp->getOpOperand(weightIndex));
break;
}
});
});
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
}
return std::nullopt;
}
} // namespace onnx_mlir
+26
View File
@@ -3,9 +3,15 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/StringRef.h"
#include <optional>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
@@ -26,4 +32,24 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
template <typename CoreLikeOpTy>
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
llvm::SmallVector<unsigned, 8> indices;
auto addWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) {
if (coreLikeOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
} // namespace onnx_mlir
+1
View File
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
+3 -3
View File
@@ -13,7 +13,8 @@
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
-1
View File
@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp
PimWeightEmitter.cpp
+1 -1
View File
@@ -48,7 +48,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
-136
View File
@@ -1,136 +0,0 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
} // namespace onnx_mlir
-13
View File
@@ -1,13 +0,0 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
+63 -13
View File
@@ -4,13 +4,16 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <limits>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
@@ -23,6 +26,13 @@ struct MemEntry {
size_t size;
};
struct MemoryValueKey {
mlir::Value value;
std::optional<unsigned> lane;
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
};
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
@@ -50,33 +60,33 @@ struct MemoryReportEntry {
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
llvm::SmallVector<std::pair<MemEntry, MemoryValueKey>, 32> memEntries;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value);
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry);
public:
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
MemoryReportRow getReportRow() const;
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const;
MemEntry getMemEntry(const MemoryValueKey& key) const;
};
class PimAcceleratorMemory {
public:
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
@@ -84,14 +94,21 @@ private:
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
: memEntriesMap(initialMemEntries), hostMem(memEntriesMap), fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
size_t getValueAddress(mlir::Value value,
const StaticValueKnowledge& knowledge = {},
std::optional<unsigned> lane = std::nullopt) const;
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
@@ -103,15 +120,24 @@ public:
void clean(mlir::Operation* op);
};
struct CoreEmissionJob {
mlir::Operation* coreLikeOp = nullptr;
size_t originalCoreId = 0;
size_t emittedCoreId = 0;
llvm::SmallVector<unsigned, 4> lanes;
std::optional<uint64_t> batchReportId;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreBinaryStream;
llvm::raw_fd_ostream* coreJsonStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
return memory.getValueAddress(value, knowledge, batchLane);
}
size_t remapCoreId(size_t coreId) const;
@@ -141,15 +167,18 @@ public:
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getIndexValue(value, knowledge);
}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
@@ -172,3 +201,24 @@ public:
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
static onnx_mlir::MemoryValueKey getEmptyKey() {
return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0};
}
static onnx_mlir::MemoryValueKey getTombstoneKey() {
return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0};
}
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
}
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
};
} // namespace llvm
+8 -15
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
@@ -15,13 +15,12 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
"pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType>
pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
@@ -49,12 +48,6 @@ llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
llvm::cl::init(4000));
llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
-2
View File
@@ -22,7 +22,6 @@ typedef enum {
typedef enum {
MergeSchedulerPeft = 0,
MergeSchedulerDcp = 1,
} PimMergeSchedulerType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
@@ -36,7 +35,6 @@ extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
-4
View File
@@ -30,20 +30,17 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass());
pm.addPass(createMergeComputeNodesPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
}
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
@@ -54,7 +51,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimCodePass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim code emitted"));
}
}
+93 -76
View File
@@ -3,17 +3,19 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <type_traits>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
@@ -126,22 +128,6 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
@@ -163,86 +149,117 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
llvm::DenseMap<mlir::Value, std::string> mapWeightValueToFileName;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto processWeight = [&](Operation* ownerOp,
mlir::Value weight,
size_t weightIndex,
size_t coreId) -> LogicalResult {
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
return success();
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) {
mapCoreWeightToFileName[coreId].insert({weight, weightFile->second});
return success();
}
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapWeightValueToFileName[weight] = fileName;
mapCoreWeightToFileName[coreId].insert({weight, fileName});
return success();
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapWeightValueToFileName[weight] = newFileName;
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
return success();
};
auto processCoreLike = [&](auto coreLikeOp) {
auto usedIndices = getUsedWeightIndices(coreLikeOp);
for (unsigned index : usedIndices) {
if (index >= coreLikeOp.getWeights().size()) {
coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range");
}
}
if constexpr (std::is_same_v<std::decay_t<decltype(coreLikeOp)>, pim::PimCoreOp>) {
size_t coreId = static_cast<size_t>(coreLikeOp.getCoreId());
for (unsigned index : usedIndices)
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
return failure();
return success();
}
else {
auto batchCoreIds = getBatchCoreIds(coreLikeOp);
SmallVector<size_t> orderedCoreIds;
llvm::SmallSet<size_t, 8> seenCoreIds;
for (int32_t coreId : batchCoreIds)
if (seenCoreIds.insert(static_cast<size_t>(coreId)).second)
orderedCoreIds.push_back(static_cast<size_t>(coreId));
for (size_t coreId : orderedCoreIds)
for (unsigned index : usedIndices)
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
return failure();
return success();
}
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
(void) processCoreLike(coreOp);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
(void) processCoreLike(cast<pim::PimCoreBatchOp>(op));
}
return mapCoreWeightToFileName;
}
@@ -4,8 +4,8 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
HostFoldability.cpp
HostLegality.cpp
CompileTime.cpp
ONNXToSpatialVerifier.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns/Math/Conv.cpp
@@ -18,13 +18,17 @@ namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
@@ -93,14 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -123,6 +132,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
@@ -131,13 +142,13 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -6,7 +6,7 @@
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
using namespace mlir;
@@ -44,7 +44,7 @@ SmallVector<Value> sliceTensor(
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isHostFoldableValue(tensorToSlice)) {
if (isCompileTimeComputable(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
@@ -113,7 +113,7 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
if (isHostFoldableValue(scalarToBroadcast))
if (isCompileTimeComputable(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
@@ -7,8 +7,11 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -145,7 +148,7 @@ static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
@@ -156,7 +159,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
@@ -169,7 +172,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
@@ -177,7 +180,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
@@ -185,7 +188,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
@@ -195,62 +198,95 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
static std::optional<CompileTimeSource>
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
if (!op)
return std::nullopt;
if (!visited.insert(op).second)
return {
{op, chainLength}
};
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
return {
{op, chainLength}
};
chainLength += 1;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt;
if (!isStaticTensorResult(op))
return false;
return std::nullopt;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
std::optional<CompileTimeSource> res = {};
for (auto operandValue : concatOp.getOperands()) {
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
if (!partialRes)
return std::nullopt;
return false;
if (!res) {
res = partialRes;
continue;
}
if(res->chainLength < partialRes->chainLength){
res = partialRes;
}
}
return res;
}
return std::nullopt;
}
} // namespace
bool isHostFoldableValue(Value value) {
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited);
}
bool isCompileTimeComputable(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
return getCompileTimeSourceImpl(definingOp, visited).has_value();
}
bool isHostFoldableOp(Operation* op) {
bool isCompileTimeOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
return getCompileTimeSourceImpl(op, visited).has_value();
}
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
return getHostConstantDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
@@ -0,0 +1,22 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
struct CompileTimeSource {
mlir::Operation* source;
size_t chainLength;
};
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
bool isCompileTimeComputable(mlir::Value value);
bool isCompileTimeOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -1,34 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
"spat.compute");
});
}
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -12,13 +13,12 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -44,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
@@ -91,7 +92,7 @@ static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
if (!transposeOp || isCompileTimeOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
@@ -117,6 +118,7 @@ void ONNXToSpatialPass::runOnOperation() {
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
@@ -155,6 +157,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
@@ -188,18 +191,9 @@ void ONNXToSpatialPass::runOnOperation() {
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
signalPassFailure();
return;
}
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
@@ -212,6 +206,7 @@ void ONNXToSpatialPass::runOnOperation() {
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
@@ -229,7 +224,7 @@ void ONNXToSpatialPass::runOnOperation() {
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
return;
@@ -0,0 +1,56 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
return;
for (Value result : op->getResults()) {
if (hasOnlySpatialMvmVmmWeightUses(result))
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
});
return;
}
});
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError(
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -5,6 +5,6 @@
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -11,7 +11,7 @@
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -391,7 +391,7 @@ static Value lowerSingleConvGroup(Value x,
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
auto wDenseAttr = getHostConstDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
@@ -412,7 +412,7 @@ static Value lowerSingleConvGroup(Value x,
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmBias = b;
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasDenseAttr = getHostConstDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
@@ -717,7 +717,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
if (llvm::all_of(groupResults, isCompileTimeComputable)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
File diff suppressed because it is too large Load Diff
@@ -10,7 +10,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
static Value
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
@@ -62,7 +55,7 @@ static Value collapseBatchDims(Value value,
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildCollapsed(value);
auto collapseCompute =
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
return collapseCompute.getResult(0);
}
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
static Value
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
@@ -126,7 +114,7 @@ static Value extractBatchMatrix(Value value,
});
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildMatrix(value);
auto batchMatrixCompute =
@@ -154,7 +142,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildTranspose(value);
auto transposeCompute =
@@ -194,7 +182,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewr
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
@@ -247,7 +235,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -91,7 +91,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewr
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
@@ -135,7 +135,7 @@ static Value squeezeReducedAxes(Value keepdimsValue,
}
auto reassociation = buildCollapseReassociation(reducedAxes);
if (isHostFoldableValue(keepdimsValue))
if (isCompileTimeComputable(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute =
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
if (llvm::all_of(inputs, isHostFoldableValue)) {
if (llvm::all_of(inputs, isCompileTimeComputable)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
@@ -5,7 +5,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -115,7 +115,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
}
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isHostFoldableValue(adaptor.getData())) {
if (isCompileTimeComputable(adaptor.getData())) {
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
return success();
}
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor")
return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
return rewriter.notifyMatchFailure(resizeOp,
"resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -61,7 +61,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
sliceSizes.push_back(resultType.getShape()[axis]);
}
if (isHostFoldableValue(adaptor.getInput())) {
if (isCompileTimeComputable(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
offset += sliceSize;
@@ -27,66 +27,26 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
return arg && canPromoteInputBlockArgument(*arg);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= block.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
return true;
}
return false;
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
@@ -96,11 +56,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
@@ -131,8 +89,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
@@ -141,17 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
mapper.map(*oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
@@ -180,11 +159,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
@@ -220,8 +197,31 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(outputArg->getLoc());
}
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
@@ -230,31 +230,45 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto newLaneArg = newCompute.getLaneArgument();
if (!newLaneArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
mapper.map(*laneArg, *newLaneArg);
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
mapper.map(*oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
}
for (Operation& op : oldBlock.without_terminator())
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
@@ -262,10 +276,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
@@ -277,8 +287,6 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
});
}
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
@@ -7,14 +7,10 @@
namespace onnx_mlir {
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
@@ -1,11 +1,14 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -15,7 +18,17 @@ using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber());
});
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
@@ -28,54 +41,75 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
if (!returnOp)
return failure();
return result.getUses().begin()->getOperandNumber();
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
if (scale == 1)
return base;
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides(destinationType.getRank(), 1);
ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
}
else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset =
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
}
if (!totalOffset)
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
return totalOffset;
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
if (computeBatchOp.getNumResults() == 0) {
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
}
else if (!inParallelOp) {
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty())
@@ -91,9 +125,22 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
returnOperandIndices[resultIndex] = *returnOperandIndex;
}
}
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
@@ -102,7 +149,21 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto oldLaneArg = computeBatchOp.getLaneArgument();
if (!oldLaneArg)
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
if (!oldWeightArg)
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
}
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
if (!oldArg)
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
@@ -114,7 +175,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
mapper.map(*oldArg, copied);
}
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
@@ -136,41 +197,52 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
return copied;
};
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
Value& hostOutputTensor = hostOutputTensors[resultIndex];
if (hostOutputTensor)
return hostOutputTensor;
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
return hostOutputTensor;
};
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
if (!firstOutputArg)
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
for (Operation& nestedOp : parallelOp.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
if (!insertSlice)
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
continue;
}
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &oldBlock)
return insertSlice.emitOpError("expected compute_batch output block argument destination");
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
hostTarget.getType(),
hostTargetOffset,
zeroOffset,
hostTarget,
mappedSource,
getTensorSizeInBytesAttr(rewriter, mappedSource));
}
continue;
}
@@ -178,6 +250,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
mapper.map(toTensorOp.getResult(), clonedTensor);
continue;
}
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
@@ -194,9 +270,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
}
}
for (Value operand : op.getOperands()) {
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
if (isExplicitHostOperand(&op, operandIndex))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
@@ -1,10 +0,0 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
@@ -22,6 +21,8 @@ add_pim_library(OMSpatialToPim
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRSCFUtils
MLIRTransformUtils
MLIRTosaDialect
OMCompilerOptions
OMPimCommon
@@ -1,5 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -10,17 +9,12 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
pim::PimSendOp::create(
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
rewriter.eraseOp(op);
return success();
}
@@ -42,47 +36,13 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
op.getSourceCoreId())
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
using OpRewritePattern::OpRewritePattern;
@@ -125,12 +85,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,42 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -1,11 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
@@ -20,8 +20,6 @@ namespace onnx_mlir {
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
template <class T>
@@ -1,3 +1,5 @@
#include <cassert>
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -29,7 +31,18 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
BlockArgument bodyArgument;
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
auto computeArg = compute.getInputArgument(inputIndex);
assert(computeArg && "expected compute input block argument");
bodyArgument = *computeArg;
}
else {
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
assert(batchArg && "expected compute_batch input block argument");
bodyArgument = *batchArg;
}
unsigned bodyArgIndex = bodyArgument.getArgNumber();
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
@@ -37,7 +50,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner);
}
@@ -3,11 +3,12 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) {
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
@@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -46,8 +53,6 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
@@ -92,7 +97,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
@@ -101,7 +108,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
if (block.getNumArguments() != computeOp.getWeights().size())
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
@@ -110,8 +117,14 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
auto weightArg = computeOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
mapping.map(*weightArg, weight);
}
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -125,15 +138,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
} // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
return success();
SmallVector<Operation*> helperChain;
@@ -143,21 +153,25 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
auto blockArg = computeOp.getInputArgument(inputIndex);
if (!blockArg)
return computeOp.emitOpError("expected compute input block arguments during lowering");
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
if (receiveOp && !blockArg->use_empty()) {
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
auto outputType = cast<ShapedType>(blockArg->getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
Value received =
PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
.getOutput();
blockArg->replaceAllUsesWith(received);
markOpToRemove(receiveOp);
continue;
}
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
@@ -167,9 +181,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
if (result.use_empty())
continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled)
@@ -193,15 +206,40 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
auto coreOp = PimCoreOp::create(
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
auto blockArg = computeOp.getInputArgument(inputIndex);
if (!blockArg)
return computeOp.emitOpError("expected compute input block arguments during input materialization");
if (blockArg->use_empty())
continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType)
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
auto copied =
PimMemCopyHostToDevOp::create(rewriter,
loc,
outputBuffer.getType(),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
outputBuffer,
input,
getTensorSizeInBytesAttr(rewriter, input))
.getOutput();
blockArg->replaceAllUsesWith(copied);
}
if (!computeOp.getInputs().empty())
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
@@ -1,21 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -76,10 +76,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
if (!BBArgValue)
return failure();
if (BBArgValue.use_empty())
if (BBArgValue->use_empty())
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
@@ -89,16 +90,17 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
if (!BBArgValue)
return failure();
if (BBArgValue.use_empty())
if (BBArgValue->use_empty())
continue;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
@@ -108,7 +110,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
}
else {
{
@@ -143,170 +145,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
};
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
}
}
}
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
@@ -383,8 +221,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -6,10 +6,12 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/FoldUtils.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -42,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
pim::PimTransposeOp>(op);
}
static void markOpToRemove(ReturnPathState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
@@ -318,7 +315,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
return success();
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
@@ -327,7 +325,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -337,15 +340,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
}
}
static void
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
static void cloneHelperChain(Value sourceValue,
ArrayRef<Operation*> helperChain,
IRRewriter& rewriter,
OperationFolder& constantFolder,
Value& clonedValue) {
IRMapping mapping;
mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -360,23 +366,26 @@ static Value emitHostCopy(IRRewriter& rewriter,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
int32_t sizeInBytes,
OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
return PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
hostTargetOffsetValue,
deviceSourceOffsetValue,
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
} // namespace
void addReturnOutputBuffers(func::ReturnOp returnOp,
IRRewriter& rewriter,
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Value currentReturnValue = returnValue;
@@ -411,70 +420,85 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
}
}
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType());
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
Location loc = producerOp->getLoc();
OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType());
if (auto returnUse = analyzeReturnUse(result)) {
Value storedValue = yieldValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
if (auto returnUse = analyzeReturnUse(producedValue)) {
Value currentStoredValue = storedValue;
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op);
markOpToRemove(op);
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp())
auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(rewriter,
loc,
outputTensor,
currentStoredValue,
0,
0,
static_cast<int32_t>(storedType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
auto resultUses = result.getUses();
auto resultUses = producedValue.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter,
loc,
outputTensor,
storedValue,
0,
0,
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
yieldValue,
storedValue,
static_cast<int32_t>(flatOffset * elementSize),
0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
producerOp->emitOpError(
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
@@ -484,7 +508,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
@@ -503,7 +527,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
@@ -513,7 +537,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
static_cast<int32_t>(elementSize),
constantFolder);
}
return ReturnPathLoweringResult::Handled;
}
@@ -521,7 +546,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
return ReturnPathLoweringResult::NotReturnPath;
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
}
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
return;
@@ -538,13 +568,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
if (isReturnHelperChainOp(op)) {
Value source = op->getOperand(0);
markOpToRemove(state, op);
markOpToRemove(op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(state, computeOp);
markOpToRemove(computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
@@ -552,24 +582,31 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
}
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
markOpToRemove(receiveOp);
return;
}
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
@@ -578,7 +615,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}
@@ -1,37 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
struct ReturnPathState {
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
mlir::IRRewriter& rewriter,
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir
@@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat<
>;
def spatToPimVMM : Pat<
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(SpatVMMOp:$srcOpRes $weight, $vector),
(PimVMMOp $weight, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
@@ -1,8 +1,11 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@@ -12,6 +15,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
@@ -21,54 +25,28 @@
#include <cassert>
#include <utility>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "Pass/PIMPasses.h"
#include "SpatialToPimPass.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};
} // namespace
} // namespace raptor
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
@@ -104,23 +82,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
IntegerAttr {});
}
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
OperationFolder& constantFolder) {
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroAttr = rewriter.getI32IntegerAttr(0);
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
return PimMemCopyHostToDevBatchOp::create(
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
return PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
tensorType,
outputBuffer,
zeroValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
sizeAttr)
.getOutput();
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
.getOutput();
}
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
static Value
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType());
ArrayRef<int64_t> shape = vectorType.getShape();
assert(isHVectorShape(shape) && "expected a horizontal vector");
@@ -131,14 +120,16 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
auto zeroAttr = rewriter.getI32IntegerAttr(0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
}
void SpatialToPimPass::runOnOperation() {
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
coreId = 0;
outputTensors.clear();
operationsToRemove.clear();
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
@@ -151,9 +142,11 @@ void SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext());
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
target.addLegalDialect<affine::AffineDialect,
PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
@@ -163,11 +156,7 @@ void SpatialToPimPass::runOnOperation() {
BuiltinDialect>();
target.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorBatchOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet initialPatterns(ctx);
@@ -181,19 +170,18 @@ void SpatialToPimPass::runOnOperation() {
RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure();
return;
@@ -202,7 +190,7 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure();
return;
@@ -228,12 +216,27 @@ void SpatialToPimPass::runOnOperation() {
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
populateAffineToStdConversionPatterns(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
ConversionTarget coreBodyTarget(*ctx);
coreBodyTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelSendOp,
spatial::SpatExtractRowsOp>();
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
if (failed(applyFullConversion(coreOp.getOperation(), coreBodyTarget, frozenCoreBodyPatterns))) {
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
signalPassFailure();
return;
@@ -243,7 +246,7 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), coreBodyTarget, frozenCoreBodyPatterns))) {
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
signalPassFailure();
return;
@@ -251,15 +254,8 @@ void SpatialToPimPass::runOnOperation() {
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure();
return;
}
replaceReturnWithOutputBuffers(returnOp, rewriter);
eraseOpsToRemove();
RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns);
@@ -278,9 +274,7 @@ void SpatialToPimPass::runOnOperation() {
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
@@ -301,7 +295,8 @@ void SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0");
}
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext());
funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -309,7 +304,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
rewriter.setInsertionPoint(vmmOp);
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
auto paddedOutputType = RankedTensorType::get(
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
@@ -334,13 +329,17 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
});
}
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext());
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType();
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
if (!hasByteSizedElementType(elementType))
return;
size_t elementByteSize = getElementTypeSizeInBytes(elementType);
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
@@ -349,10 +348,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
rewriter,
loc,
tensorType,
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
getOrCreateHostIndexConstant(
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
deviceTensor,
inputTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
@@ -374,11 +374,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
return success();
}
void SpatialToPimPass::markOpToRemove(Operation* op) {
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
if (!llvm::is_contained(operationsToRemove, op))
operationsToRemove.push_back(op);
}
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
void raptor::SpatialToPimPass::eraseOpsToRemove() {
for (Operation* op : operationsToRemove) {
op->dropAllUses();
op->erase();
}
}
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
} // namespace onnx_mlir
@@ -0,0 +1,72 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace raptor {
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
llvm::SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
llvm::SmallVector<mlir::Operation*> operationsToRemove;
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
mlir::Value producedValue,
mlir::Value storedValue,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
void markOpToRemove(mlir::Operation* op);
void eraseOpsToRemove();
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
};
} // namespace raptor
} // namespace onnx_mlir
+34 -128
View File
@@ -2,6 +2,7 @@
#define PIM_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -24,7 +25,8 @@ def PimTensor :
// Execution
//===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
def PimCoreOp : PimOp<"core", [SingleBlock,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body);
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
I32Attr:$coreId
);
let assemblyFormat = [{
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
let extraClassDeclaration = [{
::mlir::BlockArgument getWeightArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Execute equivalent batched core bodies";
let regions = (region SizedRegion<1>:$body);
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
Variadic<PimTensor>:$inputs
);
let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
@@ -61,16 +74,6 @@ def PimHaltOp : PimOp<"halt", [Terminator]> {
}];
}
def PimYieldOp : PimOp<"yield", [Terminator]> {
let summary = "Yield results from a Pim region";
let arguments = (ins
Variadic<PimTensor>:$outputs
);
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
@@ -81,57 +84,21 @@ def PimSendOp : PimOp<"send", []> {
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
I32Attr:$targetCoreId
Index:$targetCoreId
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
`(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
}];
}
def PimSendTensorOp : PimOp<"send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor to target cores";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimSendBatchOp : PimOp<"send_batch", []> {
let summary = "Send a per-lane tensor to target cores from a batched core";
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
DenseI32ArrayAttr:$targetCoreIds
);
let hasCustomAssemblyFormat = 1;
}
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
I32Attr:$sourceCoreId
Index:$sourceCoreId
);
let results = (outs
@@ -145,84 +112,18 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
`(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
}];
}
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks from source cores into one tensor";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let summary = "Receive per-lane tensors from source cores into a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasCustomAssemblyFormat = 1;
}
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
let arguments = (ins
Index:$deviceTargetOffset,
Index:$hostSourceOffset,
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
@@ -237,7 +138,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
`[` $deviceTargetOffset `,` $hostSourceOffset `]`
`(` $deviceTarget `,` $hostSource `)` attr-dict
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
}];
}
@@ -271,10 +174,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory";
let arguments = (ins
Index:$hostTargetOffset,
Index:$deviceSourceOffset,
PimTensor:$hostTarget,
PimTensor:$deviceSource,
I32Attr:$hostTargetOffset,
I32Attr:$deviceSourceOffset,
I32Attr:$size
);
@@ -289,7 +192,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
`[` $hostTargetOffset `,` $deviceSourceOffset `]`
`(` $hostTarget `,` $deviceSource `)` attr-dict
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
}];
}
@@ -374,7 +279,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let summary = "Vector-matrix multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$weight,
PimTensor:$input,
PimTensor:$outputBuffer
);
@@ -391,7 +296,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
`[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
type($outputBuffer) `)` `->` type($output)
}];
}
+33
View File
@@ -1,8 +1,41 @@
#include <string>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
}
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
}
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
setNameFn(getLaneArgument(), "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
}
void PimDialect::initialize() {
addOperations<
#define GET_OP_LIST
+150 -262
View File
@@ -20,28 +20,128 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
return parser.getBuilder().getDenseI32ArrayAttr(values);
}
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
printer << " = ";
printCompressedValueList(printer, operands, delimiter);
}
static ParseResult parseBoundValueList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren: return parser.parseRParen();
case ListDelimiter::Square: return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
printer << " " << keyword << " ";
printCompressedIntegerList(printer, coreIds);
}
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
if (failed(parser.parseOptionalKeyword(keyword)))
return success();
return parseCompressedIntegerList(parser, coreIds);
}
} // namespace
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
void PimCoreOp::print(OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printer << " ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " coreId " << getCoreId();
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " -> () ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<Type> weightTypes;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (hasCoreId && result.attributes.get("coreId"))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
return failure();
if (hasCoreId)
result.addAttribute("coreId", getI32Attr(parser, coreId));
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
return parser.parseRegion(*body, weightArgs);
}
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getLaneArgument());
printer << " = 0 to " << getLaneCount() << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
@@ -49,51 +149,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
printer << " -> ()";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> () ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
Region* body = result.addRegion();
if (parser.parseRegion(*body))
return failure();
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|| parser.parseLParen() || parser.parseRParen())
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
@@ -110,233 +216,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
return failure();
}
return success();
}
void PimYieldOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getOutputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputs;
SmallVector<Type> outputTypes;
OpAsmParser::UnresolvedOperand firstOutput;
OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
if (firstOutputResult.has_value()) {
if (failed(*firstOutputResult))
return failure();
if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, outputs))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (outputs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
Region* body = result.addRegion();
laneArg.type = builder.getIndexType();
regionArgs.push_back(laneArg);
applyArgumentTypes(weightTypes, weightArgs);
llvm::append_range(regionArgs, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
}
void PimConcatOp::print(OpAsmPrinter& printer) {
+93 -87
View File
@@ -1,9 +1,14 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -14,6 +19,63 @@ namespace pim {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static Region* getParentRegion(Value value) {
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return blockArgument.getParentRegion();
Operation* definingOp = value.getDefiningOp();
return definingOp ? definingOp->getParentRegion() : nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
if (definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
if (!getGlobalOp)
return false;
auto moduleOp = definingOp->getParentOfType<ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
return globalOp && globalOp.getConstant();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|| isExplicitHostOperand(op, operand.getOperandNumber()))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
}
@@ -28,97 +90,41 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
return success();
}
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
return success();
}
static LogicalResult
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
return op->emitError() << kind << " must be nested inside pim.core_batch";
int32_t laneCount = coreBatchOp.getLaneCount();
if (laneCount <= 0)
return op->emitError() << kind << " requires a positive parent laneCount";
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
if (weightIndex >= coreOp.getWeights().size())
return failure();
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
}
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
if (weightIndex >= coreBatchOp.getWeights().size())
return failure();
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
return failure();
return shapedType.getShape();
}
} // namespace
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
LogicalResult PimCoreOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly");
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
LogicalResult PimSendTensorBatchOp::verify() {
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
LogicalResult PimReceiveTensorBatchOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorBatchCommunication(
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
LogicalResult PimCoreBatchOp::verify() {
if (getLaneCount() <= 0)
return emitError("laneCount must be positive");
Block& block = getBody().front();
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly");
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly");
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
LogicalResult PimVMMOp::verify() {
@@ -126,9 +132,9 @@ LogicalResult PimVMMOp::verify() {
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
if (failed(matrixShapeOpt))
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
return emitError("weight must be a shaped value");
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
return PimMemCopyOp::create(rewriter,
loc,
@@ -1,9 +1,10 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
return builder.getI32IntegerAttr(sizeInBytes);
}
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
memCopyHostToDevOp,
deviceTargetMemRef.getType(),
memCopyHostToDevOp.getDeviceTargetOffset(),
memCopyHostToDevOp.getHostSourceOffset(),
deviceTargetMemRef,
hostSourceMemRef,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
memCopyDevToHostOp,
hostTargetMemRef.getType(),
memCopyDevToHostOp.getHostTargetOffset(),
memCopyDevToHostOp.getDeviceSourceOffset(),
hostTargetMemRef,
deviceSourceMemRef,
memCopyDevToHostOp.getHostTargetOffsetAttr(),
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
memCopyDevToHostOp.getSizeAttr());
return success();
}
@@ -151,78 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdAttr());
return success();
}
};
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveTensorOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveTensorBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
replaceOpWithNewBufferizedOp<PimReceiveOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
return success();
}
};
@@ -256,30 +186,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
}
};
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -302,59 +208,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdAttr());
return success();
}
};
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct SendTensorBatchOpInterface
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
sendOp.getTargetCoreId());
return success();
}
};
@@ -368,6 +222,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
return {};
}
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
auto coreOp = cast<PimCoreOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
return {};
unsigned weightIndex = bbArg.getArgNumber();
return {
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
FailureOr<BufferLikeType> getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto coreOp = cast<PimCoreOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
return failure();
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
return memRefType;
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
@@ -375,7 +260,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
auto coreOp = cast<PimCoreOp>(op);
bool alreadyBufferized =
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized)
return success();
@@ -420,9 +308,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return {};
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
unsigned argNumber = bbArg.getArgNumber();
if (argNumber == 0)
return {};
unsigned weightCount = coreBatchOp.getWeights().size();
unsigned operandIndex = argNumber - 1;
if (argNumber > weightCount + 1)
operandIndex = weightCount + (argNumber - 1 - weightCount);
return {
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
{&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
};
}
@@ -438,11 +334,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return failure();
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
unsigned argNumber = bbArg.getArgNumber();
if (argNumber == 0)
return failure();
Value tiedOperand;
unsigned weightCount = coreBatchOp.getWeights().size();
if (argNumber <= weightCount)
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
else
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
return memRefType;
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
}
LogicalResult bufferize(Operation* op,
@@ -454,8 +360,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
bool alreadyBufferized =
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized)
return success();
@@ -553,6 +460,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
if (failed(weightOpt))
return failure();
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
@@ -564,7 +475,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
return success();
}
};
@@ -646,13 +557,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
@@ -9,6 +9,7 @@
#include <limits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir;
@@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) {
}
static bool isCandidateAllocType(MemRefType type) {
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
return type && type.hasStaticShape() && type.getLayout().isIdentity()
&& hasByteSizedElementType(type.getElementType());
}
static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
}
static FailureOr<uint64_t>
@@ -50,10 +52,9 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
if (initArg == value)
pendingValues.push_back(forOp.getResult(index));
}
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
@@ -8,6 +8,7 @@
#include <fstream>
#include "Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
@@ -47,12 +48,6 @@ struct CoalescingReportEntry {
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
llvm::SmallVector<ReportField, 4> fields = {
{"Number of candidates", std::to_string(row.numCandidates)},
@@ -72,7 +67,7 @@ static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) {
}
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
std::fstream file = openReportFile("static_memory_coalescing_report");
std::fstream file = openReportFile("memory_coalescing_report");
if (!file.is_open())
return;
-9
View File
@@ -2,25 +2,16 @@ add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps
Channels.cpp
SpatialOps.cpp
SpatialOpsAsm.cpp
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
EXCLUDE_FROM_OM_LIBS
-120
View File
@@ -1,120 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
using namespace mlir;
namespace onnx_mlir::spatial {
namespace {
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
if (!endpoints.send || !endpoints.receive)
return failure();
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
return failure();
}
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
return failure();
}
if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) {
endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type");
return failure();
}
return success();
}
} // namespace
Channels::Channels(func::FuncOp funcOp) {
if (!funcOp)
return;
funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
}
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
void Channels::insertSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp);
nextChannelId = std::max(nextChannelId, channelId + 1);
endpoints[channelId].send = sendOp;
}
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp);
nextChannelId = std::max(nextChannelId, channelId + 1);
endpoints[channelId].receive = receiveOp;
}
void Channels::eraseSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp);
auto it = endpoints.find(channelId);
if (it == endpoints.end())
return;
it->second.send = {};
if (!it->second.receive)
endpoints.erase(it);
}
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp);
auto it = endpoints.find(channelId);
if (it == endpoints.end())
return;
it->second.receive = {};
if (!it->second.send)
endpoints.erase(it);
}
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
auto it = endpoints.find(id);
if (it == endpoints.end())
return failure();
return it->second;
}
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
auto endpointsOr = lookup(getChannelId(sendOp));
if (failed(endpointsOr) || !endpointsOr->receive)
return failure();
return endpointsOr->receive;
}
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
auto endpointsOr = lookup(getChannelId(receiveOp));
if (failed(endpointsOr) || !endpointsOr->send)
return failure();
return endpointsOr->send;
}
LogicalResult Channels::verify() const {
for (const auto& [channelId, pair] : endpoints) {
if (!pair.send || !pair.receive) {
if (pair.send) {
auto sendOp = pair.send;
sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
}
else if (pair.receive) {
auto receiveOp = pair.receive;
receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
}
return failure();
}
if (failed(verifyEndpointPair(pair)))
return failure();
}
return success();
}
} // namespace onnx_mlir::spatial
-43
View File
@@ -1,43 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir::spatial {
struct ChannelEndpoints {
SpatChannelSendOp send;
SpatChannelReceiveOp receive;
};
class Channels {
public:
using ChannelId = int64_t;
explicit Channels(mlir::func::FuncOp funcOp);
ChannelId allocate();
void insertSend(SpatChannelSendOp sendOp);
void insertReceive(SpatChannelReceiveOp receiveOp);
void eraseSend(SpatChannelSendOp sendOp);
void eraseReceive(SpatChannelReceiveOp receiveOp);
llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;
mlir::LogicalResult verify() const;
private:
ChannelId nextChannelId = 0;
llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
};
} // namespace onnx_mlir::spatial
+70 -141
View File
@@ -2,8 +2,12 @@
#define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def SpatialDialect : Dialect {
let name = "spat";
@@ -22,7 +26,9 @@ def SpatTensor :
// Execution
//===----------------------------------------------------------------------===//
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
def SpatCompute : SpatOp<"compute",
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Compute region with attached constant weights";
let arguments = (ins
@@ -36,14 +42,27 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatComputeBatch : SpatOp<"compute_batch",
[SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compressed batch of independent equivalent compute lanes";
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
let arguments = (ins
I32Attr:$laneCount,
@@ -57,10 +76,48 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatInParallelOp : SpatOp<"in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"SpatComputeBatch">,
] # GraphRegionNoTerminator.traits> {
let summary = "Parallel combining terminator for resultful spat.compute_batch";
let regions = (region SizedRegion<1>:$region);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins)>,
];
let extraClassDeclaration = [{
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];
}
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
let summary = "Yield results from a compute region";
@@ -110,14 +167,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a logical channel";
let arguments = (ins
I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId,
Index:$channelId,
Index:$sourceCoreId,
Index:$targetCoreId,
SpatTensor:$input
);
let assemblyFormat = [{
$input attr-dict `:` type($input)
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
}];
}
@@ -125,9 +182,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a logical channel";
let arguments = (ins
I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId
Index:$channelId,
Index:$sourceCoreId,
Index:$targetCoreId
);
let results = (outs
@@ -135,103 +192,10 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
);
let assemblyFormat = [{
attr-dict `:` type($output)
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
}];
}
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
@@ -240,7 +204,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins
I32Attr:$weightIndex,
SpatTensor:$weight,
SpatTensor:$input
);
@@ -251,26 +215,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins
I32Attr:$weightIndex,
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}];
}
@@ -310,22 +255,6 @@ def SpatVMulOp : SpatOp<"vmul", []> {
}];
}
def SpatSumOp : SpatOp<"sum", []> {
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVAvgOp : SpatOp<"vavg", []> {
let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
+230
View File
@@ -1,9 +1,239 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include <string>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
std::optional<BlockArgument> getBlockArgument(Region& body, unsigned argIdx) {
if (body.empty())
return std::nullopt;
Block& block = body.front();
if (argIdx >= block.getNumArguments())
return std::nullopt;
return block.getArgument(argIdx);
}
std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx, Type type, Location loc) {
if (body.empty())
return std::nullopt;
return body.insertArgument(argIdx, type, loc);
}
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
CrossbarWeightSet collectCrossbarWeights(Region& body) {
CrossbarWeightSet weights;
body.walk([&](SpatVMMOp vmmOp) {
Value weight = vmmOp.getWeight();
weights.insert(weight);
});
return weights;
}
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), getWeights().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatCompute>>
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
}
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
}
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
}
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
continue;
if (index == 0) {
setNameFn(*outputArg, "out");
continue;
}
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
}
}
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
OpBuilder::InsertionGuard guard(builder);
Region* bodyRegion = result.addRegion();
builder.createBlock(bodyRegion);
}
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
void SpatialDialect::initialize() {
addTypes<
+8
View File
@@ -5,10 +5,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include <map>
#include <optional>
#include <string>
#include <tuple>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
+191 -217
View File
@@ -10,6 +10,7 @@
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp"
using namespace mlir;
@@ -23,22 +24,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void printChannelMetadata(OpAsmPrinter& printer,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
printer << " channels ";
printCompressedIntegerList(printer, channelIds);
printer << " from ";
printCompressedIntegerList(printer, sourceCoreIds);
printer << " to ";
printCompressedIntegerList(printer, targetCoreIds);
}
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
return parser.getBuilder().getDenseI64ArrayAttr(values);
}
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
return parser.getBuilder().getDenseI32ArrayAttr(values);
}
@@ -47,94 +32,86 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
template <typename TensorSendOpTy>
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
printer << " ";
printer.printOperand(op.getInput());
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(op->getAttrs(),
{op.getChannelIdsAttrName().getValue(),
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getInput().getType());
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
if (index != 0)
printer << ", ";
printer.printOperand(argument);
}
printer << ")";
}
template <typename TensorReceiveOpTy>
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(op->getAttrs(),
{op.getChannelIdsAttrName().getValue(),
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getOutput().getType());
}
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (parser.parseLParen())
return failure();
if (succeeded(parser.parseOptionalRParen()))
return success();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
OpAsmParser::Argument argument;
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
return parser.parseRParen();
}
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
ArrayRef<Type> weightTypes,
ArrayRef<Type> outputTypes,
OpAsmParser::Argument& laneArg,
SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
Builder& builder) {
laneArg.type = builder.getIndexType();
regionArgs.push_back(laneArg);
applyArgumentTypes(weightTypes, weightArgs);
llvm::append_range(regionArgs, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
applyArgumentTypes(outputTypes, outputArgs);
llvm::append_range(regionArgs, inputArgs);
llvm::append_range(regionArgs, outputArgs);
}
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
printer << " = ";
printCompressedValueList(printer, operands, delimiter);
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
static ParseResult parseBoundValueList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
result.addTypes(outputType);
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren: return parser.parseRParen();
case ListDelimiter::Square: return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
@@ -242,10 +219,28 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
}
void SpatCompute::print(OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
@@ -264,23 +259,31 @@ void SpatCompute::print(OpAsmPrinter& printer) {
}
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseArgumentBindings(parser, regionArgs, inputs))
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
@@ -292,9 +295,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size())
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
@@ -313,19 +318,59 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(inputTypes, regionArgs);
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
auto laneArg = getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
if (getNumResults() != 0) {
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
printer.printOperand(*laneArg);
printer << " = 0 to " << getLaneCount();
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
if (getNumResults() != 0) {
printer << " shared_outs";
printBlockArgumentList(printer, outputArgs);
}
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
@@ -337,10 +382,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
@@ -350,27 +392,45 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
}
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (parseArgumentBindings(parser, regionArgs, inputs))
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
@@ -381,10 +441,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size())
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
@@ -403,119 +468,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(inputTypes, regionArgs);
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
}
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
void SpatInParallelOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
printer.printOptionalAttrDict((*this)->getAttrs());
}
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
auto& builder = parser.getBuilder();
std::unique_ptr<Region> region = std::make_unique<Region>();
SmallVector<OpAsmParser::Argument, 4> regionArgs;
if (parser.parseRegion(*region, regionArgs))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutput().getType());
}
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputType);
return success();
}
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
if (region->empty())
OpBuilder(builder.getContext()).createBlock(region.get());
result.addRegion(std::move(region));
return parser.parseOptionalAttrDict(result.attributes);
}
} // namespace spatial
+250 -259
View File
@@ -1,6 +1,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
@@ -10,6 +15,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -20,210 +26,204 @@ namespace spatial {
namespace {
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error.
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto batchOp = op->getParentOfType<SpatComputeBatch>();
if (!batchOp)
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
return failure();
return batchOp.getLaneCount();
return shapedType.getShape();
}
static LogicalResult verifyTensorChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
return false;
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
unsigned argNumber = blockArg.getArgNumber();
auto firstOutputArg = batchOp.getOutputArgument(0);
if (!firstOutputArg)
return false;
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
}
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
template <typename ComputeOpTy>
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
for (Value weight : computeOp.getWeights())
if (!isCompileTimeComputable(weight))
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
return success();
}
static LogicalResult verifyBatchChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
static bool isConstantIndexLike(Value value) {
APInt constantValue;
return matchPattern(value, m_ConstantInt(&constantValue));
}
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
static bool isSupportedLaneAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (affineApply) {
if (affineApply.getAffineMap().getNumResults() != 1 || affineApply.getAffineMap().getNumSymbols() != 0)
return false;
if (!llvm::all_of(affineApply.getMapOperands(),
[&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) {
return false;
}
return isSupportedLaneAffineExpr(affineApply.getAffineMap().getResult(0));
}
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
if (extractOp) {
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
return false;
return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
}
auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp)
return false;
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
}
static LogicalResult
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (Value offset : offsets)
if (!isSupportedLaneOffsetExpr(offset, laneArg))
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
return success();
}
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
BlockArgument laneArg,
StringRef kind) {
RankedTensorType sourceType = sliceOp.getSourceType();
RankedTensorType destType = sliceOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
auto offsets = sliceOp.getOffsets();
for (Value offset : offsets)
if (!isSupportedLaneOffsetExpr(offset, laneArg))
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
return success();
}
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return op->emitError("body must terminate with spat.yield");
if (outputTypes.empty()) {
static Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
if (yieldOp.getNumOperands() != 0)
return op->emitError("body yield must be empty when compute_batch has no results");
return batchOp.emitError("resultless compute_batch body yield must be empty");
}
else {
if (yieldOp.getNumOperands() != 1)
return op->emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputTypes[0])
return op->emitError("body yield type must match output type");
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
}
auto laneArg = batchOp.getLaneArgument();
if (!laneArg)
return batchOp.emitError("compute_batch body must have a lane block argument");
for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
return failure();
}
return success();
}
} // namespace
LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -347,13 +347,28 @@ LogicalResult verifyComputeResultsUses(Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute" );
return op->emitError("ComputeResult used directly inside another Compute");
}
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
unsigned expectedArgCount = getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
@@ -386,93 +401,24 @@ LogicalResult SpatCompute::verify() {
}
}
for (auto arg : block.getArguments())
if (arg.use_empty())
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
return emitError("ComputeOp block argument is not used");
if (failed(verifyStaticWeights(*this, "compute")))
return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
return failure();
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return success();
}
LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_send_tensor");
}
LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_receive_tensor");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_send_tensor_batch");
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_receive_tensor_batch");
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
if (count <= 0)
return emitError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count);
if (getWeights().size() % laneCountSz != 0)
return emitError("number of weights must be a multiple of laneCount");
if (!getInputs().empty() && getInputs().size() != laneCountSz)
return emitError("number of inputs must be either 0 or laneCount");
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
return emitError("number of outputs must be either 0 or laneCount");
size_t weightsPerLane = getWeights().size() / laneCountSz;
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
Type weightType = getWeights()[weightIndex].getType();
for (size_t lane = 1; lane < laneCountSz; ++lane)
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
return emitError("corresponding weights across lanes must have the same type");
}
if (!getInputs().empty()) {
Type inputType = getInputs()[0].getType();
for (Value in : getInputs().drop_front())
if (in.getType() != inputType)
return emitError("all inputs must have the same type");
}
if (!getOutputs().empty()) {
Type outputType = getOutputs()[0].getType();
for (Value out : getOutputs().drop_front())
if (out.getType() != outputType)
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
@@ -482,27 +428,72 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be non-negative");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
DenseSet<int32_t> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch coreIds values must be distinct");
return emitError("compute_batch coreIds values must be unique");
}
Block& block = getBody().front();
if (getInputs().empty()) {
if (block.getNumArguments() != 0)
return emitError("compute_batch body must have no block arguments when there are no inputs");
if (block.getNumArguments() == 0)
return emitError("compute_batch body must have exactly one lane block argument");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
auto laneArg = getLaneArgument();
if (!laneArg || !laneArg->getType().isIndex())
return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
}
else {
if (block.getNumArguments() != 1)
return emitError("compute_batch body must have exactly one block argument");
if (block.getArgument(0).getType() != getInputs()[0].getType())
return emitError("body block argument type must match input type");
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly");
}
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
auto blockArg = getOutputArgument(resultIndex);
if (!blockArg || blockArg->getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly");
}
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
if (failed(verifyStaticWeights(*this, "compute_batch")))
return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
return failure();
return verifyBatchBody(*this, block);
}
LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return emitOpError("expected spat.compute_batch parent");
if (batchOp.getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
auto laneArg = batchOp.getLaneArgument();
if (!laneArg)
return emitOpError("expected compute_batch lane block argument");
for (Operation& op : getRegion().front().getOperations()) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSliceOp)
return emitOpError("expected only tensor.parallel_insert_slice ops");
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
return failure();
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get()))
return op.emitOpError("may only insert into a compute_batch output block argument");
}
return success();
}
} // namespace spatial
@@ -1,20 +0,0 @@
#include "DCPAnalysis.hpp"
#include "../Scheduling/ComputeGraph.hpp"
#include "../Scheduling/DcpScheduler.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
namespace spatial {
DCPAnalysisResult DCPAnalysis::run() {
ComputeGraph graph = buildComputeGraph(entryOp);
DcpScheduleOptions options;
if (coresCount.getValue() > 0)
options.processorCount = static_cast<size_t>(coresCount.getValue());
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
options.allowFallbackForAutoCoreCount = true;
return runDcpScheduler(graph, options, entryOp->getContext());
}
} // namespace spatial
} // namespace onnx_mlir
@@ -1,28 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "../Scheduling/MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
using DCPAnalysisResult = MergeScheduleResult;
struct DCPAnalysis {
private:
DCPAnalysisResult result;
mlir::Operation *entryOp;
DCPAnalysisResult run();
public:
DCPAnalysis(mlir::Operation *op)
: entryOp(op) {
result = run();
}
DCPAnalysisResult &getResult() { return result; }
};
} // namespace spatial
} // namespace onnx_mlir
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
File diff suppressed because it is too large Load Diff
@@ -1,178 +0,0 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <list>
#include <optional>
#include <unordered_map>
#include <vector>
#include "DCPAnalysis.hpp"
#include "Task.hpp"
#include "Utils.hpp"
namespace mlir {
class MLIRContext;
} // namespace mlir
std::optional<EdgePair> addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling = false);
void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling = false);
Weight getTransferCost(TaskDCP* parent, TaskDCP* child);
class GraphDCP {
public:
struct CandidateRelations {
llvm::DenseSet<TaskDCP*> ancestors;
llvm::DenseSet<TaskDCP*> descendants;
// descendants ordered by position in the graph's topological order;
// iterating this avoids walking non-descendant tail tasks on hot paths.
llvm::SmallVector<TaskDCP*, 32> descendantsTopoOrder;
};
struct ScheduledTaskInfo {
size_t nodeIndex;
Time aest;
Time alst;
Weight weight;
};
private:
using CpuTaskList = std::list<TaskDCP*>;
struct FindSlot {
Time aest;
int index;
};
std::vector<TaskDCP> nodes;
onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
std::vector<uint64_t> taskStructureHashes;
std::vector<CpuTaskList> cpuTasks;
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
llvm::DenseMap<CPU, uint64_t> cpuStructureHashes;
CPU lastCpu = 0;
long long flag = 1;
Time dcpl = 0;
Time maxCompletion = 0;
Time secondMaxCompletion = 0;
TaskDCP* maxCompletionTask = nullptr;
int maxCpuCount = 1000;
mlir::MLIRContext* context = nullptr;
TaskInsertion insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position);
void removeTaskFromCPU(CPU cpu, TaskDCP* task);
CpuTaskList& getOrCreateCpuTasks(CPU cpu);
const CpuTaskList* findCpuTasks(CPU cpu) const;
std::vector<TaskDCP*> getRoots();
long long getUniqueFlag() { return flag++; }
void initAest();
void initAlst();
void initTaskStructureHashes();
Time computeAestOnCpu(TaskDCP* task, CPU cpu);
Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
Time getDcpl() const { return dcpl; }
Time computeTaskAlstOnCpu(TaskDCP* task, CPU cpu, Time scheduleDcpl);
void updateAestFromTask(TaskDCP* task);
void updateAestFromTaskWithDescendants(TaskDCP* task, const llvm::DenseSet<TaskDCP*>& descendants);
void updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder);
// Propagates AEST like the overload above but returns early (before touching
// the remaining descendants) as soon as a task's completion exceeds
// `dcplBudget`, signalling that the new DCPL would exceed the budget.
// Returns true iff the full propagation completed without exceeding the
// budget. Uses the caller's snapshot to restore AEST on the aborted tail.
bool tryUpdateAestWithinBudget(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder, Time dcplBudget);
// Incrementally refreshes ALST after `task` has been scheduled. Nodes
// outside the backward cone (`relations.ancestors` plus `task`) retain
// their relative distance to the sink boundary and only absorb the signed
// DCPL delta (`newDcpl - oldDcpl`). `task` itself and its ancestors are
// recomputed in reverse topological order so that new same-CPU transfer
// costs (now zero) and scheduling-edge children are reflected.
void updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl);
void initTopological();
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
size_t getNodeIndex(const TaskDCP* task) const;
// Returns a compact dedup key for CPU `c` when evaluating `candidate`:
// mixes candidateAest, crossbar usage, and the incremental cpu structure
// hash into a single uint64_t. Zero heap allocation.
uint64_t computeCpuCandidateKey(Time candidateAest, CPU cpu);
CandidateRelations selectProcessor(TaskDCP* candidate, bool push);
CPU getLastCpu() const { return lastCpu; }
void incrementLastCpu() { lastCpu++; }
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
FindSlot findSlotWithFixedFinalTime(
TaskDCP* candidate, CPU cpu, const CandidateRelations& relations, Time finalTime, Time aestOnCpu);
void dumpDot();
friend TaskInsertion;
friend class TaskDCP;
CrossbarUsage getCpuCrossbarUsage(CPU cpu) const;
CrossbarUsage getCpuCrossbarCapacity() const;
CrossbarUsage getTaskCrossbarFootprint(const TaskDCP* task) const;
void reserveTaskCrossbars(CPU cpu, const TaskDCP* task);
void releaseTaskCrossbars(CPU cpu, const TaskDCP* task);
bool wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const;
public:
void runDcp();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes, llvm::ArrayRef<IndexedEdge> edges)
: nodes(), cpuTasks(), cpuCrossbarUsage() {
for (auto spatCompute : spatComputes)
nodes.emplace_back(spatCompute);
for (auto [start, end, weight] : edges)
makeEdge(start, end, weight);
}
GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
llvm::ArrayRef<IndexedEdge> edges,
llvm::ArrayRef<int64_t> nodeOrderKeys = {},
llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
: nodes(), cpuTasks(), cpuCrossbarUsage() {
assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
&& "synthetic crossbar usage must match synthetic node weights");
assert((nodeOrderKeys.empty() || nodeOrderKeys.size() == nodeWeights.size())
&& "synthetic node order keys must match synthetic node weights");
nodes.reserve(nodeWeights.size());
for (auto [index, weight] : llvm::enumerate(nodeWeights))
nodes.emplace_back(nodeOrderKeys.empty() ? static_cast<int64_t>(index) : nodeOrderKeys[index],
weight,
nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
for (auto [start, end, weight] : edges)
makeEdge(start, end, weight);
}
DCPAnalysisResult getResult();
std::vector<ScheduledTaskInfo> getScheduledTasks(CPU cpu) const;
CPU cpuCount() const { return lastCpu; }
void makeEdge(size_t parentIndex, size_t childIndex, Weight weight) {
addEdge(&nodes[parentIndex], &nodes[childIndex], weight);
}
size_t taskInCpu(CPU cpu) { return getOrCreateCpuTasks(cpu).size(); }
void setMaxCpuCount(int value) { maxCpuCount = value; }
int getMaxCpuCount() const { return maxCpuCount; }
// Total crossbar units allocated across all active CPUs.
size_t crossbarsUsed() const;
// Maximum crossbar units available across all active CPUs (lastCpu * per-CPU capacity).
size_t crossbarsAvailable() const;
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
// null the scheduler runs single-threaded (tests use this path).
void setContext(mlir::MLIRContext* ctx) { context = ctx; }
};
@@ -1,154 +0,0 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <fstream>
#include <string>
#include "GraphDebug.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace dcp_graph {
#ifdef DCP_DEBUG_ENABLED
DcpProgressLogger::DcpProgressLogger(size_t totalTasks)
: logProgress(totalTasks >= 200),
totalTasks(totalTasks),
startTime(std::chrono::steady_clock::now()),
lastProgressPrint(startTime) {}
std::string DcpProgressLogger::formatDuration(double seconds) {
if (seconds < 0)
seconds = 0;
long totalSeconds = static_cast<long>(seconds + 0.5);
long hours = totalSeconds / 3600;
long minutes = (totalSeconds % 3600) / 60;
long secs = totalSeconds % 60;
if (hours > 0)
return llvm::formatv("{0}:{1:02}:{2:02}", hours, minutes, secs).str();
return llvm::formatv("{0}:{1:02}", minutes, secs).str();
}
void DcpProgressLogger::recordFindDuration(double seconds) { findCandidateSeconds += seconds; }
void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSeconds += seconds; }
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
if (!logProgress)
return;
llvm::errs() << llvm::formatv("[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
totalTasks,
readyCount,
maxCpuCount,
xbarsCapacity);
}
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
double elapsedSeconds,
size_t readyCount,
CPU cpuCount) const {
if (!logProgress || elapsedSeconds < 1.0)
return;
llvm::errs() << llvm::formatv("[DCP] slow node={0} elapsed={1} ready={2} cpus={3}\n",
nodeIndex,
formatDuration(elapsedSeconds),
readyCount,
cpuCount);
}
void DcpProgressLogger::printProgress(
size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force) {
if (!logProgress)
return;
auto now = std::chrono::steady_clock::now();
if (!force && now - lastProgressPrint < std::chrono::seconds(1) && completedTasks != totalTasks)
return;
double elapsedSeconds = std::chrono::duration<double>(now - startTime).count();
double rate = elapsedSeconds > 0.0 ? static_cast<double>(completedTasks) / elapsedSeconds : 0.0;
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
bool done = completedTasks == totalTasks;
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
completedTasks,
totalTasks,
percent,
readyCount,
cpuCount,
maxCpuCount,
xbarsUsed,
xbarsAvailable,
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
lastProgressPrint = now;
}
#else
DcpProgressLogger::DcpProgressLogger(size_t) {}
void DcpProgressLogger::recordFindDuration(double) {}
void DcpProgressLogger::recordSelectDuration(double) {}
void DcpProgressLogger::recordUpdateDuration(double) {}
void DcpProgressLogger::advanceCompleted(size_t) {}
void DcpProgressLogger::printStart(size_t, int, size_t) const {}
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
#endif
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu) {
static int dumpIndex = 0;
std::string outputDir = onnx_mlir::getOutputDir();
if (outputDir.empty())
return;
std::string graphDir = outputDir + "/dcp_graph";
onnx_mlir::createDirectory(graphDir);
std::fstream file(graphDir + "/graph_" + std::to_string(dumpIndex++) + ".dot", std::ios::out);
file << "digraph G {\n";
if (!cpuTasks.empty()) {
for (CPU cpu = 0; cpu < lastCpu; cpu++) {
file << "subgraph cluster_" << cpu << "{\nstyle=filled;\ncolor=lightgrey;\n";
size_t cpuIndex = static_cast<size_t>(cpu);
if (cpuIndex >= cpuTasks.size()) {
file << " }\n";
continue;
}
for (auto node : cpuTasks[cpuIndex]) {
file << node->Id() << " [label=\"";
file << "n:" << node->Id() << "\n";
file << "aest:" << node->getAest() << "\n";
file << "alst:" << node->getAlst() << "\n";
file << "weight:" << node->getWeight() << "\"]\n";
}
file << " }\n";
}
}
else {
for (const auto& node : nodes) {
file << node.Id() << " [label=\"";
file << "n:" << node.Id() << "\n";
file << "aest:" << node.getAest() << "\n";
file << "alst:" << node.getAlst() << "\n";
file << "weight:" << node.getWeight() << "\"]\n";
}
}
for (const auto& node : nodes)
for (const auto& child : node.children) {
file << node.Id() << " -> " << child.first->Id();
file << " [label=\"" << child.second << "\"]\n";
}
file << "}\n";
file.flush();
file.close();
}
} // namespace dcp_graph
@@ -1,57 +0,0 @@
#pragma once
#include "llvm/ADT/StringRef.h"
#include <chrono>
#include <list>
#include <vector>
#include "Task.hpp"
#include "Utils.hpp"
// Define DCP_DEBUG_ENABLED locally when debugging DCP progress and per-phase
// profiling. In normal builds the logger methods are no-ops and helpers compile
// away.
#define DCP_DEBUG_ENABLED
#ifdef DCP_DEBUG_ENABLED
#define DCP_DEBUG_IF(...) __VA_ARGS__
#else
#define DCP_DEBUG_IF(...)
#endif
namespace dcp_graph {
class DcpProgressLogger {
public:
explicit DcpProgressLogger(size_t totalTasks);
void recordFindDuration(double seconds);
void recordSelectDuration(double seconds);
void recordUpdateDuration(double seconds);
void advanceCompleted(size_t taskCount = 1);
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
void
printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force);
#ifdef DCP_DEBUG_ENABLED
private:
static std::string formatDuration(double seconds);
bool logProgress = false;
size_t totalTasks = 0;
size_t completedTasks = 0;
std::chrono::steady_clock::time_point startTime;
std::chrono::steady_clock::time_point lastProgressPrint;
double findCandidateSeconds = 0.0;
double selectProcessorSeconds = 0.0;
double updateTimingSeconds = 0.0;
#endif
};
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu);
} // namespace dcp_graph
@@ -1,104 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include <vector>
#include "GraphSupport.hpp"
#include "Task.hpp"
#include "UniqueWorklist.hpp"
namespace dcp_graph {
llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents) {
llvm::DenseSet<TaskDCP*> reachable;
std::vector<TaskDCP*> worklist;
worklist.reserve(32);
auto enqueueEdges = [&](TaskDCP* task) {
const auto& edges = followParents ? task->parents : task->children;
for (const auto& edge : edges)
if (reachable.insert(edge.first).second)
worklist.push_back(edge.first);
};
enqueueEdges(root);
while (!worklist.empty()) {
TaskDCP* task = worklist.back();
worklist.pop_back();
enqueueEdges(task);
}
return reachable;
}
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate) {
return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false), {}};
}
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
const llvm::DenseSet<TaskDCP*>& descendants,
Time dcpl,
Time maxCompletion,
Time secondMaxCompletion,
TaskDCP* maxCompletionTask) {
LocalScheduleSnapshot snapshot;
snapshot.aestBackup.reserve(descendants.size() + 1);
snapshot.aestBackup.emplace_back(task, task->getAest());
for (TaskDCP* descendant : descendants)
snapshot.aestBackup.emplace_back(descendant, descendant->getAest());
snapshot.dcpl = dcpl;
snapshot.maxCompletion = maxCompletion;
snapshot.secondMaxCompletion = secondMaxCompletion;
snapshot.maxCompletionTask = maxCompletionTask;
return snapshot;
}
void restoreLocalScheduleState(const LocalScheduleSnapshot& snapshot,
Time& dcpl,
Time& maxCompletion,
Time& secondMaxCompletion,
TaskDCP*& maxCompletionTask) {
for (const auto& [task, aest] : snapshot.aestBackup)
task->setAest(aest);
dcpl = snapshot.dcpl;
maxCompletion = snapshot.maxCompletion;
secondMaxCompletion = snapshot.secondMaxCompletion;
maxCompletionTask = snapshot.maxCompletionTask;
}
int countDependencyParents(const TaskDCP* task) {
return static_cast<int>(llvm::count_if(task->parents, [](const Edge& edge) { return !edge.isScheduling; }));
}
void recordTopologicalMove(TaskDCP* task, TaskInsertion* insertion) {
if (insertion == nullptr)
return;
auto alreadyRecorded =
llvm::any_of(insertion->topologicalMoves,
[task](const TaskInsertion::TopologicalMoveRecord& move) { return move.task == task; });
if (alreadyRecorded)
return;
insertion->topologicalMoves.push_back({task, onnx_mlir::LabeledList<TaskDCP>::next(task)});
}
std::vector<TaskDCP*> collectDominanceOrder(llvm::ArrayRef<TaskDCP*> roots, size_t nodeCount) {
UniqueWorkList<std::vector<TaskDCP*>> worklist(roots);
worklist.reserve(nodeCount);
size_t index = 0;
while (index != worklist.size()) {
bool modified = true;
while (modified) {
modified = false;
for (const auto& child : worklist.at(index)->children)
if (worklist.allElementsContained(
child.first->parents.begin(), child.first->parents.end(), [](Edge edge) { return edge.first; }))
modified |= worklist.pushBack(child.first);
}
index++;
}
return {worklist.begin(), worklist.end()};
}
} // namespace dcp_graph
@@ -1,41 +0,0 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>
#include <vector>
#include "Graph.hpp"
namespace dcp_graph {
struct LocalScheduleSnapshot {
llvm::SmallVector<std::pair<TaskDCP*, Time>, 64> aestBackup;
Time dcpl = 0;
Time maxCompletion = 0;
Time secondMaxCompletion = 0;
TaskDCP* maxCompletionTask = nullptr;
};
llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents);
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate);
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
const llvm::DenseSet<TaskDCP*>& descendants,
Time dcpl,
Time maxCompletion,
Time secondMaxCompletion,
TaskDCP* maxCompletionTask);
void restoreLocalScheduleState(const LocalScheduleSnapshot& snapshot,
Time& dcpl,
Time& maxCompletion,
Time& secondMaxCompletion,
TaskDCP*& maxCompletionTask);
int countDependencyParents(const TaskDCP* task);
void recordTopologicalMove(TaskDCP* task, TaskInsertion* insertion);
std::vector<TaskDCP*> collectDominanceOrder(llvm::ArrayRef<TaskDCP*> roots, size_t nodeCount);
} // namespace dcp_graph
@@ -1,66 +0,0 @@
#include <optional>
#include "Graph.hpp"
#include "Task.hpp"
#include "UniqueWorklist.hpp"
std::optional<Edge> TaskDCP::addChild(TaskDCP* child, Weight weight, bool isScheduling) {
std::optional<Edge> oldEdge = std::nullopt;
auto foundElement = std::find_if(children.begin(), children.end(), [child, isScheduling](Edge element) {
return child == element.first && isScheduling == element.isScheduling;
});
if (foundElement != children.end()) {
oldEdge = *foundElement;
fastRemove(children, foundElement);
}
children.emplace_back(Edge {child, weight, isScheduling});
return oldEdge;
}
std::optional<Edge> TaskDCP::addParent(TaskDCP* parent, Weight weight, bool isScheduling) {
std::optional<Edge> oldEdge = std::nullopt;
auto foundElement = std::find_if(parents.begin(), parents.end(), [parent, isScheduling](Edge element) {
return parent == element.first && isScheduling == element.isScheduling;
});
if (foundElement != parents.end()) {
oldEdge = *foundElement;
fastRemove(parents, foundElement);
}
parents.emplace_back(Edge {parent, weight, isScheduling});
return oldEdge;
}
bool TaskDCP::hasDescendant(TaskDCP* child) {
UniqueWorkList<std::vector<TaskDCP*>> worklist;
worklist.reserve(32);
worklist.pushBack(this);
while (!worklist.empty()) {
TaskDCP* task = worklist.back();
worklist.popBack();
if (task == child)
return true;
for (auto edge : task->children)
worklist.pushBack(edge.first);
}
return false;
}
Weight TaskDCP::computeWeightOnCpu(GraphDCP* graph, CPU cpu) {
if (crossbarUsage != 0 && graph->wouldExhaustCrossbarCapacity(cpu, this))
return std::numeric_limits<Weight>::max();
return baseWeight;
}
void TaskInsertion::rollBack() {
graph->removeTaskFromCPU(cpuModified, taskInserted);
if (beforeNode.has_value()) {
auto edgePair = *beforeNode;
addEdge(edgePair.first.first, edgePair.second.first, edgePair.first.second, edgePair.first.isScheduling);
}
if (afterNode.has_value()) {
auto edgePair = *afterNode;
addEdge(edgePair.first.first, edgePair.second.first, edgePair.first.second, edgePair.first.isScheduling);
}
// for (auto it = topologicalMoves.rbegin(); it != topologicalMoves.rend(); ++it)
// graph->topologicalOrder.moveBefore(it->task, it->nextTask);
}
@@ -1,126 +0,0 @@
#pragma once
#include <cassert>
#include <optional>
#include <vector>
#include "Utils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
onnx_mlir::spatial::SpatCompute spatCompute;
Time aest;
Time alst;
std::optional<CPU> scheduledCpu;
Weight weight;
Weight baseWeight;
CrossbarUsage crossbarUsage;
long long flag = 0;
int64_t syntheticId = -1;
std::optional<Edge> addChild(TaskDCP* child, Weight weight, bool isScheduling);
std::optional<Edge> addChild(TaskDCP& child, Weight weight, bool isScheduling) {
return addChild(&child, weight, isScheduling);
}
void removeChild(TaskDCP* toRemove, bool isScheduling) { fastRemove(children, toRemove, isScheduling); }
void removeChild(TaskDCP& toRemove, bool isScheduling) { fastRemove(children, &toRemove, isScheduling); }
std::optional<Edge> addParent(TaskDCP* parent, Weight weight, bool isScheduling);
std::optional<Edge> addParent(TaskDCP& parent, Weight weight, bool isScheduling) {
return addParent(&parent, weight, isScheduling);
}
void removeParent(TaskDCP* toRemove, bool isScheduling) { fastRemove(parents, toRemove, isScheduling); }
void removeParent(TaskDCP& toRemove, bool isScheduling) { fastRemove(parents, &toRemove, isScheduling); }
public:
std::vector<Edge> parents;
std::vector<Edge> children;
TaskDCP() = default;
TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatCompute(spatCompute),
aest(0),
alst(0),
scheduledCpu(),
weight(getSpatComputeWeight(spatCompute)),
baseWeight(weight),
crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
syntheticId(-1),
parents(),
children() {}
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatCompute(),
aest(0),
alst(0),
scheduledCpu(),
weight(weight),
baseWeight(weight),
crossbarUsage(crossbarUsage),
flag(0),
syntheticId(id),
parents(),
children() {}
TaskDCP(const TaskDCP& node) = delete;
TaskDCP(TaskDCP&& node) = default;
void setCpu(CPU cpu) { scheduledCpu = cpu; }
std::optional<CPU> getCpu() const { return scheduledCpu; }
void resetCpu() { scheduledCpu = std::nullopt; }
Weight getWeight() const {
if (isScheduled())
return weight;
return baseWeight;
}
void setWeight(Weight value) { weight = value; }
void resetWeight() { weight = baseWeight; }
Weight computeWeightOnCpu(GraphDCP* graph, CPU cpu);
CrossbarUsage getCrossbarUsage() const { return crossbarUsage; }
bool hasParents() const { return parents.size() != 0; }
bool hasChildren() const { return children.size() != 0; }
Time getAest() const { return aest; }
Time getAlst() const { return alst; }
void setAest(Time value) { aest = value; }
void setAlst(Time value) { alst = value; }
bool hasDescendant(TaskDCP* child);
int64_t Id() const {
if (spatCompute)
return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
return syntheticId;
}
bool isCriticalPath() const { return alst == aest; }
bool isScheduled() const { return scheduledCpu.has_value(); }
onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }
void setFlag(long long val) { flag = val; }
long long getFlag() const { return flag; }
onnx_mlir::LabeledList<TaskDCP>::Iterator getTopologicalIterator() { return getIterator(); }
friend std::optional<EdgePair> addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling);
friend void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling);
friend Weight getTransferCost(TaskDCP* parent, TaskDCP* child);
};
struct TaskInsertion {
struct TopologicalMoveRecord {
TaskDCP* task;
TaskDCP* nextTask;
};
std::optional<EdgePair> beforeNode;
std::optional<EdgePair> afterNode;
std::vector<TopologicalMoveRecord> topologicalMoves;
CPU cpuModified;
TaskDCP* taskInserted;
GraphDCP* graph;
void rollBack();
};
@@ -1,83 +0,0 @@
#pragma once
#include "llvm/ADT/DenseSet.h"
#include <cassert>
#include <type_traits>
template <typename T, typename = void>
struct HasPopFront : std::false_type {};
template <typename T>
struct HasPopFront<T, std::void_t<decltype(std::declval<T>().pop_front())>> : std::true_type {};
template <typename T>
class UniqueWorkList {
using ValueType = typename T::value_type;
T storage;
llvm::DenseSet<ValueType> uniqueElements;
public:
UniqueWorkList() = default;
template <typename RangeT>
UniqueWorkList(const RangeT& from)
: storage() {
for (auto& element : from) {
if (!uniqueElements.contains(element)) {
storage.push_back(element);
uniqueElements.insert(element);
}
}
}
bool empty() const { return storage.empty(); }
void reserve(size_t value) { return storage.reserve(value); }
size_t size() const { return storage.size(); }
ValueType& at(size_t index) { return storage.at(index); }
const ValueType& at(size_t index) const { return storage.at(index); }
ValueType& front() { return storage.front(); }
ValueType& back() { return storage.back(); }
bool pushBack(const ValueType& value) {
if (!uniqueElements.contains(value)) {
storage.push_back(value);
uniqueElements.insert(value);
return true;
}
return false;
}
void popFront() {
if constexpr (HasPopFront<T>::value)
storage.pop_front();
else
assert(false && "Underlying storage type does not support pop_front()");
}
auto cbegin() const { return storage.cbegin(); }
auto cend() const { return storage.cend(); }
void popBack() { storage.pop_back(); }
template <typename Iterator, typename Mapper>
bool allElementsContained(Iterator begin, Iterator end, Mapper map) const {
auto it = begin;
while (it != end) {
if (!uniqueElements.contains(map(*it)))
return false;
std::advance(it, 1);
}
return true;
}
auto begin() { return storage.begin(); }
auto end() { return storage.end(); }
auto begin() const { return storage.begin(); }
auto end() const { return storage.end(); }
};
@@ -1,111 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cstdint>
#include <limits>
#include <list>
#include <type_traits>
#include <utility>
#include <vector>
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using CPU = int;
using Weight = unsigned long long;
using Time = unsigned long long;
using CrossbarUsage = unsigned long long;
class TaskDCP;
class GraphDCP;
struct Edge {
TaskDCP* first;
Weight second;
bool isScheduling = false;
};
using EdgePair = std::pair<Edge, Edge>;
using IndexedEdge = std::tuple<int64_t, int64_t, int64_t>;
inline void fastRemove(std::vector<Edge>& vector, TaskDCP* toRemove, bool isScheduling) {
auto position = std::find_if(vector.begin(), vector.end(), [toRemove, isScheduling](Edge edge) {
return edge.first == toRemove && edge.isScheduling == isScheduling;
});
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
}
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* toRemove) {
auto position =
std::find_if(vector.begin(), vector.end(), [toRemove](TaskDCP* element) { return element == toRemove; });
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
}
template <typename P>
void fastRemove(std::vector<Edge>& vector, P position) {
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
}
template <typename T>
inline T checkedAdd(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types");
assert(lhs <= std::numeric_limits<T>::max() - rhs && "unsigned addition overflow");
return lhs + rhs;
}
template <typename T>
inline T checkedMultiply(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "checkedMultiply only supports unsigned types");
if (lhs == 0 || rhs == 0)
return 0;
assert(lhs <= std::numeric_limits<T>::max() / rhs && "unsigned multiplication overflow");
return lhs * rhs;
}
template <typename T>
inline T addOrMax(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "addOrMax only supports unsigned types");
if (lhs == std::numeric_limits<T>::max() || rhs == std::numeric_limits<T>::max())
return std::numeric_limits<T>::max();
return checkedAdd(lhs, rhs);
}
template <typename T>
inline T subtractOrZero(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "subtractOrZero only supports unsigned types");
if (lhs == std::numeric_limits<T>::max())
return lhs;
if (rhs == std::numeric_limits<T>::max() || lhs <= rhs)
return 0;
return lhs - rhs;
}
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : spatCompute.getBody())
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
CrossbarUsage crossbarUsage = 0;
for (auto& region : spatCompute.getBody())
for (auto& inst : region)
if (llvm::isa<onnx_mlir::spatial::SpatVMMOp>(inst))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
File diff suppressed because it is too large Load Diff
@@ -10,8 +10,7 @@ namespace spatial {
class MergeScheduleMaterializer {
public:
mlir::LogicalResult
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
};
} // namespace spatial
@@ -1,8 +1,6 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
@@ -27,7 +25,6 @@
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <memory>
#include <optional>
#include <tuple>
@@ -35,7 +32,7 @@
#include <vector>
#include "MaterializeMergeSchedule.hpp"
#include "PostMergeCompaction.hpp"
#include "Scheduling/ComputeGraph.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp"
#include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
@@ -57,8 +54,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
class ScopedMergePhaseTimer {
public:
explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()),
phase(phaseName.str()) {
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
if (enabled)
start = std::chrono::steady_clock::now();
}
@@ -82,8 +78,6 @@ struct MergeIrCounts {
uint64_t topLevelComputeBatchCount = 0;
uint64_t scalarChannelSendCount = 0;
uint64_t scalarChannelReceiveCount = 0;
uint64_t tensorChannelSendCount = 0;
uint64_t tensorChannelReceiveCount = 0;
uint64_t wvmmCount = 0;
uint64_t vaddCount = 0;
uint64_t scfForCount = 0;
@@ -98,10 +92,6 @@ MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
++counts.scalarChannelSendCount;
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
++counts.scalarChannelReceiveCount;
else if (isa<spatial::SpatChannelSendTensorOp, spatial::SpatChannelSendTensorBatchOp>(nestedOp))
++counts.tensorChannelSendCount;
else if (isa<spatial::SpatChannelReceiveTensorOp, spatial::SpatChannelReceiveTensorBatchOp>(nestedOp))
++counts.tensorChannelReceiveCount;
else if (isa<spatial::SpatVMMOp>(nestedOp))
++counts.wvmmCount;
else if (isa<spatial::SpatVAddOp>(nestedOp))
@@ -130,14 +120,10 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
MergeIrCounts counts = collectMergeIrCounts(funcOp);
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
<< " compute=" << counts.topLevelComputeCount
<< " compute_batch=" << counts.topLevelComputeBatchCount
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount
<< " tensor_send=" << counts.tensorChannelSendCount
<< " tensor_recv=" << counts.tensorChannelReceiveCount
<< " wvmm=" << counts.wvmmCount
<< " vadd=" << counts.vaddCount
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount
<< " scf_for=" << counts.scfForCount << "\n";
}
@@ -149,7 +135,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
struct ComputeMotifInfo {
uint64_t instructionCount = 0;
uint64_t weightedMvmCount = 0;
uint64_t weightedVmmCount = 0;
};
@@ -167,21 +152,21 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
}
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) {
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
ValueRange sourceWeights) {
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights()))
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
targetWeightIndices[weight].push_back(weightIndex);
DenseMap<Value, size_t> usedSourceWeightOccurrences;
SmallVector<size_t> sourceToTargetIndex;
sourceToTargetIndex.reserve(sourceWeights.size());
auto targetWeights = target.getWeightsMutable();
for (Value weight : sourceWeights) {
size_t occurrence = usedSourceWeightOccurrences[weight]++;
auto& matchingIndices = targetWeightIndices[weight];
if (occurrence >= matchingIndices.size()) {
size_t newIndex = target.getWeights().size();
targetWeights.append(weight);
size_t newIndex = targetWeights.size();
targetWeights.push_back(weight);
matchingIndices.push_back(newIndex);
sourceToTargetIndex.push_back(newIndex);
continue;
@@ -213,37 +198,50 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
auto& computeUse = *compute->getUses().begin();
auto child = cast<SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
auto childInputIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
SmallVector<Value> mergedWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(mergedWeights, child.getWeights());
SmallVector<Value> mergedInputs(compute.getInputs().begin(), compute.getInputs().end());
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), mergedWeights, mergedInputs);
Block* newBody = rewriter.createBlock(&newCompute.getBodyRegion());
for (Value weight : mergedWeights)
newBody->addArgument(weight.getType(), loc);
for (Value input : mergedInputs)
newBody->addArgument(input.getType(), loc);
IRMapping mapper;
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights());
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex]));
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
auto remapWeightIndex = [&](auto weightedOp) {
auto oldIndex = weightedOp.getWeightIndex();
assert(static_cast<size_t>(oldIndex) < childWeightToNewIndex.size() && "weight index out of range");
weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]);
};
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
remapWeightIndex(weightedMvmOp);
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
remapWeightIndex(weightedVmmOp);
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
assert(oldWeightArg && newWeightArg && "expected compute weight block arguments");
mapper.map(*oldWeightArg, *newWeightArg);
}
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) {
auto oldInputArg = compute.getInputArgument(inputIndex);
auto newInputArg = newCompute.getInputArgument(inputIndex);
assert(oldInputArg && newInputArg && "expected compute input block arguments");
mapper.map(*oldInputArg, *newInputArg);
}
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) {
auto oldWeightArg = child.getWeightArgument(oldIndex);
auto newWeightArg = newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]);
assert(oldWeightArg && newWeightArg && "expected child compute weight block arguments");
mapper.map(*oldWeightArg, *newWeightArg);
}
rewriter.setInsertionPointToEnd(newBody);
auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
for (Operation& op : compute.getBody().front().without_terminator())
rewriter.clone(op, mapper);
auto childInputArg = child.getInputArgument(childInputIndex);
assert(childInputArg && "expected child compute input block argument");
mapper.map(*childInputArg, mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
rewriter.setInsertionPointToEnd(newBody);
for (auto& op : child.getBody().front())
rewriter.clone(op, mapper);
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
@@ -283,13 +281,8 @@ void emitMotifProfile(func::FuncOp funcOp) {
for (auto [index, compute] : llvm::enumerate(computes)) {
ComputeMotifInfo& info = computeInfos[index];
for (Operation& op : compute.getBody().front()) {
info.instructionCount++;
if (isa<spatial::SpatMVMOp>(&op))
info.weightedMvmCount++;
if (isa<spatial::SpatVMMOp>(&op))
info.weightedVmmCount++;
}
info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
if (info.weightedVmmCount > 0) {
weightedVmmNodeCount++;
weightedVmmOpCount += info.weightedVmmCount;
@@ -400,7 +393,7 @@ void emitMotifProfile(func::FuncOp funcOp) {
wideWeightedVmmLevels256 += count >= 256;
}
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
SmallVector<ShapeKey> weightedVmmShapeKeys;
for (auto [index, compute] : llvm::enumerate(computes)) {
const ComputeMotifInfo& info = computeInfos[index];
@@ -408,7 +401,6 @@ void emitMotifProfile(func::FuncOp funcOp) {
continue;
weightedVmmShapeKeys.push_back({info.instructionCount,
info.weightedVmmCount,
info.weightedMvmCount,
static_cast<uint64_t>(compute.getWeights().size()),
static_cast<uint64_t>(compute.getInputs().size()),
static_cast<uint64_t>(parents[index].size()),
@@ -461,14 +453,13 @@ void emitMotifProfile(func::FuncOp funcOp) {
for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
auto [count, shape] = weightedVmmShapeCounts[rank];
auto [insts, vmmOps, mvmOps, weights, inputs, fanIn, fanOut] = shape;
auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
"mvmOps={4} weights={5} inputs={6} fanIn={7} fanOut={8}\n",
"weights={4} inputs={5} fanIn={6} fanOut={7}\n",
rank,
count,
insts,
vmmOps,
mvmOps,
weights,
inputs,
fanIn,
@@ -485,7 +476,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
struct ReportRow {
uint64_t id = 0;
uint64_t logicalComputeCount = 0;
uint64_t weightCount = 0;
uint64_t crossbarCount = 0;
uint64_t instructionCount = 0;
bool isRebatched = false;
SmallVector<int32_t> coreIds;
@@ -495,38 +486,40 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0;
uint64_t totalInstructionCount = 0;
uint64_t totalWeightCount = 0;
uint64_t totalCrossbarCount = 0;
uint64_t nextBatchId = 0;
std::vector<ReportRow> collectedData;
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
};
for (Operation& op : funcOp.getBody().front()) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
uint64_t numInst = 0;
for (auto& _ : spatCompute.getRegion().front())
++numInst;
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute))
coreIds.push_back(*coreId);
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, coreIds});
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
totalLogicalComputes += 1;
totalInstructionCount += numInst;
totalWeightCount += spatCompute.getWeights().size();
totalCrossbarCount += perInstanceCrossbarCount;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
uint64_t numInst = 0;
for (auto& _ : batch.getRegion().front())
++numInst;
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
SmallVector<int32_t> coreIds;
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
collectedData.push_back({nextBatchId++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
totalComputeOps += 1;
totalLogicalComputes += logicalCount;
totalBatchComputeOps += 1;
totalInstructionCount += numInst * logicalCount;
totalWeightCount += batch.getWeights().size();
totalCrossbarCount += perInstanceCrossbarCount * logicalCount;
}
}
@@ -536,7 +529,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
{"Number of logical computes", std::to_string(totalLogicalComputes) },
{"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) },
{"Number of instructions", std::to_string(totalInstructionCount)},
{"Number of used crossbars", std::to_string(totalWeightCount) }
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
};
printReportTotalsBlock(os, totalFields);
if (!collectedData.empty())
@@ -550,7 +543,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) {
ReportRow next = collectedData[nI];
if (current.isRebatched == next.isRebatched && current.weightCount == next.weightCount
if (current.isRebatched == next.isRebatched && current.crossbarCount == next.crossbarCount
&& current.instructionCount == next.instructionCount
&& current.logicalComputeCount == next.logicalComputeCount)
lastIndex = nI;
@@ -583,20 +576,20 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
os << ":\n";
uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount;
uint64_t perCoreInstructionCount = current.instructionCount;
uint64_t perCoreWeightCount =
current.logicalComputeCount == 0 ? 0 : current.weightCount / current.logicalComputeCount;
uint64_t perCoreCrossbarCount =
current.logicalComputeCount == 0 ? 0 : current.crossbarCount / current.logicalComputeCount;
uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount;
llvm::SmallVector<ReportField, 3> perCoreFields = {
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
{"Number of instructions", std::to_string(perCoreInstructionCount) },
{"Number of used crossbars", std::to_string(perCoreWeightCount) }
{"Number of used crossbars", std::to_string(perCoreCrossbarCount) }
};
if (current.isRebatched) {
llvm::SmallVector<ReportField, 3> totalEntryFields = {
{"Number of logical computes", std::to_string(current.logicalComputeCount)},
{"Number of instructions", std::to_string(totalEntryInstructionCount) },
{"Number of used crossbars", std::to_string(current.weightCount) }
{"Number of used crossbars", std::to_string(current.crossbarCount) }
};
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
}
@@ -651,13 +644,6 @@ public:
emitMergeIrCounts("after-materialization", func);
if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
signalPassFailure();
return;
}
emitMergeIrCounts("after-post-merge-compaction", func);
{
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
if (!sortTopologically(&func.getBody().front())) {
@@ -667,7 +653,7 @@ public:
}
emitMergeIrCounts("final-post-merge", func);
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
generateReport(func, "dcp_merge_report", analysisResult->cpuToLastComputeMap.size());
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
}
}
};
@@ -1,459 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <chrono>
#include <cstdlib>
#include <limits>
#include <optional>
#include "PostMergeCompaction.hpp"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
class ScopedMergePhaseTimer {
public:
explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
if (enabled)
start = std::chrono::steady_clock::now();
}
~ScopedMergePhaseTimer() {
if (!enabled)
return;
auto elapsed = std::chrono::steady_clock::now() - start;
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
}
private:
bool enabled = false;
std::string phase;
std::chrono::steady_clock::time_point start;
};
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(coreIdAttr.getInt());
return std::nullopt;
}
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
return static_cast<uint64_t>(phaseAttr.getInt());
return std::nullopt;
}
struct RebatchKey {
unsigned inputCount = 0;
unsigned resultCount = 0;
unsigned weightCount = 0;
uint64_t phase = 0;
bool hasPhase = false;
uint64_t structureHash = 0;
bool operator==(const RebatchKey& other) const {
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
}
};
struct RebatchKeyInfo {
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
static unsigned getHashValue(const RebatchKey& key) {
return static_cast<unsigned>(
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
}
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
};
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
RebatchKey computeRebatchKey(SpatCompute compute) {
llvm::hash_code structureHash =
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
for (Value weight : compute.getWeights())
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
structureHash = llvm::hash_combine(structureHash, *phase);
Block& body = compute.getBody().front();
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
for (BlockArgument arg : body.getArguments())
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
for (Operation& op : body) {
structureHash = llvm::hash_combine(
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
for (Type type : op.getResultTypes())
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
for (NamedAttribute attr : op.getAttrs())
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
}
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
return {static_cast<unsigned>(compute.getInputs().size()),
static_cast<unsigned>(compute.getResultTypes().size()),
static_cast<unsigned>(compute.getWeights().size()),
phase.value_or(0),
phase.has_value(),
static_cast<uint64_t>(structureHash)};
}
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
if (!lhs || !rhs)
return false;
if (lhs.getInputs().size() != rhs.getInputs().size())
return false;
if (lhs.getResultTypes() != rhs.getResultTypes())
return false;
if (lhs.getWeights().size() != rhs.getWeights().size())
return false;
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
return false;
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
return false;
auto& lhsBlock = lhs.getBody().front();
auto& rhsBlock = rhs.getBody().front();
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
return false;
DenseMap<Value, Value> mappedValues;
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
if (lhsArg.getType() != rhsArg.getType())
return false;
mappedValues[lhsArg] = rhsArg;
}
auto lhsIt = lhsBlock.begin();
auto rhsIt = rhsBlock.begin();
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
Operation& lhsOp = *lhsIt;
Operation& rhsOp = *rhsIt;
if (lhsOp.getName() != rhsOp.getName())
return false;
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
return false;
if (lhsOp.getNumResults() != rhsOp.getNumResults())
return false;
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
return false;
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
auto mapped = mappedValues.find(lhsOperand);
if (mapped != mappedValues.end()) {
if (mapped->second != rhsOperand)
return false;
continue;
}
if (lhsOperand != rhsOperand)
return false;
}
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
return false;
}
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
return false;
}
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
return false;
}
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
return false;
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
mappedValues[lhsResult] = rhsResult;
}
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
DenseSet<Operation*> consumed;
DenseMap<Operation*, size_t> computeOrder;
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
for (auto [index, compute] : llvm::enumerate(computes)) {
computeOrder[compute.getOperation()] = index;
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
}
for (size_t index = 0; index < computes.size(); ++index) {
auto anchor = computes[index];
if (consumed.contains(anchor))
continue;
if (anchor.getInputs().size() > 1)
continue;
if (!anchor.getResults().empty())
continue;
SmallVector<SpatCompute> group {anchor};
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
if (auto coreId = getComputeCoreId(anchor))
usedCoreIds.insert(*coreId);
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
if (bucketIt == candidatesByKey.end())
continue;
for (auto candidate : bucketIt->second) {
if (computeOrder.lookup(candidate.getOperation()) <= index)
continue;
if (consumed.contains(candidate))
continue;
if (!areEquivalentForRebatch(anchor, candidate))
continue;
if (auto coreId = getComputeCoreId(candidate))
if (!usedCoreIds.insert(*coreId).second)
continue;
group.push_back(candidate);
}
if (group.size() <= 1)
continue;
auto insertionAnchor = group.front();
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
llvm::stable_sort(
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
}
SmallVector<Value> weights;
weights.reserve(group.size() * anchor.getWeights().size());
SmallVector<Value> inputs;
inputs.reserve(group.size() * anchor.getInputs().size());
SmallVector<int32_t> coreIds;
coreIds.reserve(group.size());
bool haveAllCoreIds = true;
for (auto compute : group) {
llvm::append_range(weights, compute.getWeights());
llvm::append_range(inputs, compute.getInputs());
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
haveAllCoreIds = false;
else if (haveAllCoreIds)
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
}
rewriter.setInsertionPoint(insertionAnchor);
auto rebatched = SpatComputeBatch::create(rewriter,
insertionAnchor.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
ValueRange(weights),
ValueRange(inputs));
rebatched.getProperties().setOperandSegmentSizes(
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (haveAllCoreIds)
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
auto* newBlock =
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(newBlock);
IRMapping mapper;
auto& anchorBlock = anchor.getBody().front();
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
for (Operation& anchorOp : anchorBlock) {
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
struct BatchReceiveEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchReceiveEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
entries.push_back(
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchReceiveEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
receiveOp.getLoc(),
receiveOp.getOutput().getType(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
continue;
}
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
struct BatchSendEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchSendEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchSendEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
spatial::SpatChannelSendBatchOp::create(rewriter,
sendOp.getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
mapper.lookup(sendOp.getInput()));
continue;
}
if (isa<spatial::SpatYieldOp>(anchorOp)) {
for (auto& opIt : opIts)
++opIt;
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
continue;
}
Operation* cloned = rewriter.clone(anchorOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
for (auto& opIt : opIts)
++opIt;
}
for (auto compute : group) {
compute->removeAttr(kRebatchPhaseAttrName);
consumed.insert(compute);
rewriter.eraseOp(compute);
}
}
for (auto compute : funcOp.getOps<SpatCompute>())
compute->removeAttr(kRebatchPhaseAttrName);
}
void cleanupDeadPackingOps(func::FuncOp funcOp) {
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
op.erase();
};
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatConcatOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
} // namespace
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
{
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
orderBilateralChannelOps(funcOp);
}
{
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
rebatchEquivalentComputes(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-regular-op-runs");
compactRegularOpRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
compactRowWiseWvmmRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp);
}
return success();
}
} // namespace onnx_mlir
@@ -1,12 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
namespace onnx_mlir {
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
} // namespace onnx_mlir

Some files were not shown because too many files have changed in this diff Show More