Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 | |||
| a50e77ff38 | |||
| f56c4159b5 | |||
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| a103ba328b | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 |
+80
-12
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
||||
|
||||
project(raptor)
|
||||
|
||||
# Add symlink to PIM as accelerator in onnx-mlir
|
||||
function(raptor_ensure_symlink link_path target_path)
|
||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
||||
# Materialize a CMake shim directory
|
||||
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||
|
||||
if(NOT EXISTS "${link_parent}")
|
||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
||||
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||
message(FATAL_ERROR
|
||||
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||
" ${real_external_source_dir}"
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (IS_SYMLINK "${shim_dir}")
|
||||
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||
file(REMOVE "${shim_dir}")
|
||||
endif ()
|
||||
|
||||
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||
endif ()
|
||||
|
||||
file(MAKE_DIRECTORY "${shim_dir}")
|
||||
|
||||
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||
set(shim_contents
|
||||
"get_filename_component(raptor_external_source_dir
|
||||
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||
REALPATH
|
||||
)
|
||||
add_subdirectory(
|
||||
\"\${raptor_external_source_dir}\"
|
||||
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||
)
|
||||
if (DEFINED PIM_ENABLED)
|
||||
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||
endif ()
|
||||
"
|
||||
)
|
||||
|
||||
if (EXISTS "${shim_file}")
|
||||
file(READ "${shim_file}" old_contents)
|
||||
else ()
|
||||
set(old_contents "")
|
||||
endif ()
|
||||
|
||||
if (NOT old_contents STREQUAL shim_contents)
|
||||
file(WRITE "${shim_file}" "${shim_contents}")
|
||||
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||
else ()
|
||||
message(STATUS "CMake shim already up to date for ${description}")
|
||||
endif ()
|
||||
|
||||
# Mirror the external tree's first-level entries into the shim directory
|
||||
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||
|
||||
foreach (child IN LISTS children)
|
||||
if (child STREQUAL "CMakeLists.txt")
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
set(real_child "${real_external_source_dir}/${child}")
|
||||
set(shim_child "${shim_dir}/${child}")
|
||||
|
||||
if (IS_SYMLINK "${shim_child}")
|
||||
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||
if (existing_link_target STREQUAL real_child)
|
||||
continue()
|
||||
endif ()
|
||||
file(REMOVE_RECURSE "${shim_child}")
|
||||
elseif (EXISTS "${shim_child}")
|
||||
# Do not delete real files/directories. This protects the generated shim.
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
if(NOT EXISTS "${link_path}")
|
||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
||||
file(CREATE_LINK
|
||||
"${target_path}"
|
||||
"${link_path}"
|
||||
"${real_child}"
|
||||
"${shim_child}"
|
||||
SYMBOLIC
|
||||
)
|
||||
endif()
|
||||
endforeach ()
|
||||
endfunction()
|
||||
|
||||
raptor_ensure_symlink(
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
"PIM accelerator"
|
||||
)
|
||||
raptor_ensure_symlink(
|
||||
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
"PIM accelerator tests"
|
||||
)
|
||||
|
||||
# Patch onnx-mlir sources for PIM accelerator support.
|
||||
|
||||
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||
run only the codegen tail.
|
||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||
per-core count.
|
||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
||||
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||
merge-compute-nodes pass (default: `peft`).
|
||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
|
||||
```
|
||||
validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
--core-count 1000
|
||||
```
|
||||
|
||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||
@@ -142,6 +145,46 @@ validate.py \
|
||||
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||
```
|
||||
|
||||
Each validation run writes debugging artifacts into the benchmark's workspace
|
||||
directory (for example `validation/operations/gemm/small/`):
|
||||
- `inputs/` — generated input CSVs used for the run.
|
||||
- `outputs/` — reference outputs dumped by the native ONNX runner.
|
||||
- `raptor/` — compiler artifacts:
|
||||
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
|
||||
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
|
||||
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
|
||||
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
|
||||
`raptor/reports/` such as `dcp_merge_report.txt`,
|
||||
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
|
||||
- `runner/` — generated reference runner source, build tree, and shared library.
|
||||
- `simulation/out.bin` — raw simulator output dump used for output comparison.
|
||||
|
||||
That means you usually do not need to rerun standalone `--EmitSpatial` or
|
||||
`--EmitPim` commands while debugging validation failures: the per-pass dialect
|
||||
dumps are already available under `raptor/dialects/`.
|
||||
|
||||
The validator does not currently expose a simulator tracing flag, but once a
|
||||
validation has produced `raptor/pim/` you can rerun the simulator manually with
|
||||
tracing enabled:
|
||||
|
||||
```bash
|
||||
cd backend-simulators/pim/pim-simulator
|
||||
cargo run --no-default-features --features tracing --release \
|
||||
--package pim-simulator --bin pim-simulator -- \
|
||||
-f /path/to/workspace/raptor/pim \
|
||||
-o /path/to/workspace/simulation/out.bin \
|
||||
-d <addr0>,<size0>,<addr1>,<size1>,...
|
||||
```
|
||||
|
||||
With `--features tracing`, the simulator writes per-core traces as
|
||||
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
|
||||
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
|
||||
and the model output shapes. If you need a clean slate before rerunning, use:
|
||||
|
||||
```bash
|
||||
validate.py --clean
|
||||
```
|
||||
|
||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
||||
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
||||
|
||||
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
|
||||
.lock()
|
||||
.unwrap()
|
||||
.init(executor.cpu().num_core(), args.output.clone());
|
||||
executor.execute();
|
||||
executor.execute()?;
|
||||
dump_memory(executor, &args)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
time::{Duration, SystemTime},
|
||||
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
|
||||
send_recv: SendRecv,
|
||||
}
|
||||
|
||||
struct DeadlockInfo {
|
||||
cycle: String,
|
||||
states: String,
|
||||
}
|
||||
|
||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||
let mut tot_instructions = 0;
|
||||
let mut progress = 0;
|
||||
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute<'b>(&'b mut self)
|
||||
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||
where
|
||||
'a: 'b,
|
||||
{
|
||||
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||
print_status(cores_instructions);
|
||||
check_cycle(cpu, cores_instructions, send_recv);
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
now = SystemTime::now();
|
||||
}
|
||||
}
|
||||
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
print_status(cores_instructions);
|
||||
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
if cores_instructions
|
||||
.iter()
|
||||
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||
{
|
||||
bail!("Execution stalled with unfinished instructions");
|
||||
}
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
TRACER.lock().unwrap().report();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn cpu(&self) -> &CPU<'a> {
|
||||
@@ -201,11 +228,11 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
|
||||
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CoreState {
|
||||
SendingTo(i32),
|
||||
ReceivingFrom(i32),
|
||||
SendingTo(i32, i32),
|
||||
ReceivingFrom(i32, i32),
|
||||
Working,
|
||||
Halted,
|
||||
}
|
||||
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
let (this_core, target_core) = data.get_core_immcore();
|
||||
|
||||
if isa_recv(functor_address) {
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core));
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||
} else if isa_send(functor_address) {
|
||||
states.insert(this_core, CoreState::SendingTo(target_core));
|
||||
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||
} else {
|
||||
states.insert(this_core, CoreState::Working);
|
||||
}
|
||||
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
|
||||
for (&core_id, state) in states.iter() {
|
||||
match state {
|
||||
CoreState::SendingTo(target_core) => {
|
||||
CoreState::SendingTo(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::ReceivingFrom(core_id) {
|
||||
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::ReceivingFrom(target_core) => {
|
||||
CoreState::ReceivingFrom(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::SendingTo(core_id) {
|
||||
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
let cycle = cycle
|
||||
.iter()
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core, size, target)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core, size, source)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core),
|
||||
CoreState::Halted => format!("core {} halted", core),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
|
||||
// bail!("Deadlock detected: {}", cycle_msg);
|
||||
break; // Stop tracing
|
||||
return Some(DeadlockInfo {
|
||||
cycle: cycle_msg,
|
||||
states: states_msg,
|
||||
});
|
||||
}
|
||||
|
||||
// Hit a known branch that didn't result in a cycle
|
||||
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
current_core = waiting_for;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_wait_sync<'a, 'b, 'c>(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
add_pim_library(OMPimCommon
|
||||
IR/AddressAnalysis.cpp
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||
return mlir::failure();
|
||||
|
||||
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
@@ -110,6 +153,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
@@ -118,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Block* getHostConstantBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||
return current->getBlock();
|
||||
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
return moduleOp.getBody();
|
||||
return anchorOp->getBlock();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
@@ -12,6 +13,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::memref::AllocOp,
|
||||
@@ -29,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
|
||||
@@ -21,12 +21,13 @@ namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
found |= mvmOp.getWeight() == weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
found |= vmmOp.getWeight() == weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
if (parentOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weights = coreOp.getWeights();
|
||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||
if (weightIndex < weights.size())
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
for (mlir::OpOperand& use : weight.getUses())
|
||||
if (use.getOwner() == coreBatchOp.getOperation())
|
||||
callback(use);
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
@@ -7,10 +7,34 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
numFailures++;
|
||||
if (numFailures <= maxReportedFailures)
|
||||
emit(op);
|
||||
}
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||
<< failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
int64_t maxReportedFailures;
|
||||
int64_t numFailures = 0;
|
||||
};
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
|
||||
@@ -24,113 +28,132 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
|
||||
IRRewriter rewriter(scalarCore.getContext());
|
||||
SmallVector<Operation*> batchOps;
|
||||
scalarCore.walk([&](Operation* op) {
|
||||
if (isa<pim::PimSendBatchOp,
|
||||
pim::PimSendTensorBatchOp,
|
||||
pim::PimReceiveBatchOp,
|
||||
pim::PimReceiveTensorBatchOp,
|
||||
pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
batchOps.push_back(op);
|
||||
}
|
||||
});
|
||||
static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
OperationFolder& constantFolder) {
|
||||
Block& oldBlock = coreBatchOp.getBody().front();
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightCount = coreBatchOp.getWeights().size();
|
||||
|
||||
for (Operation* op : batchOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (blockArg.getType().isIndex()) {
|
||||
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (argIndex <= weightCount) {
|
||||
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t inputIndex = argIndex - 1 - weightCount;
|
||||
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
|
||||
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
pim::PimSendOp::create(
|
||||
builder,
|
||||
sendBatchOp.getLoc(),
|
||||
sendBatchOp.getInput(),
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
rewriter.eraseOp(op);
|
||||
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
rewriter,
|
||||
builder,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
sendTensorBatchOp.getInput(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
rewriter.eraseOp(op);
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(rewriter,
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
auto scalarReceive = pim::PimReceiveOp::create(
|
||||
builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
receiveBatchOp.getOutputBuffer(),
|
||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
|
||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
rewriter,
|
||||
builder,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
receiveTensorBatchOp.getOutputBuffer(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
|
||||
builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
memcpBatchOp.getDeviceTarget(),
|
||||
memcpBatchOp.getHostSource(),
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
mapper.lookup(memcpBatchOp.getHostSource()),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
rewriter.replaceOp(op, scalarCopy->getResults());
|
||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
ArrayRef<unsigned> lanes,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
assert(!lanes.empty() && "expected at least one batch lane");
|
||||
|
||||
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||
OpBuilder builder(scratchModule->getContext());
|
||||
OperationFolder constantFolder(scratchModule->getContext());
|
||||
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||
SmallVector<Value> laneWeights;
|
||||
laneWeights.reserve(weightsPerLane);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||
|
||||
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
|
||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||
auto scalarCore = pim::PimCoreOp::create(
|
||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||
int32_t coreId = coreIds[lanes.front()];
|
||||
for (unsigned lane : lanes)
|
||||
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
|
||||
|
||||
auto scalarCore =
|
||||
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
IRMapping mapper;
|
||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
|
||||
for (unsigned lane : lanes)
|
||||
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
|
||||
return callback(scalarCore);
|
||||
}
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -9,5 +9,8 @@ namespace onnx_mlir {
|
||||
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
llvm::ArrayRef<unsigned> lanes,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -41,15 +41,23 @@ using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||
if (elementType.isIndex())
|
||||
return sizeof(int64_t);
|
||||
if (elementType.isIntOrFloat())
|
||||
return elementType.getIntOrFloatBitWidth() / 8;
|
||||
llvm_unreachable("unsupported shaped element type");
|
||||
}
|
||||
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
@@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
|
||||
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
|
||||
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
|
||||
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("ld",
|
||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||
loadOp.getDeviceTargetOffset(),
|
||||
*deviceTargetOffset,
|
||||
addressOf(loadOp.getHostSource(), knowledge),
|
||||
loadOp.getHostSourceOffset(),
|
||||
*hostSourceOffset,
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
|
||||
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
|
||||
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
|
||||
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("st",
|
||||
addressOf(storeOp.getHostTarget(), knowledge),
|
||||
storeOp.getHostTargetOffset(),
|
||||
*hostTargetOffset,
|
||||
addressOf(storeOp.getDeviceSource(), knowledge),
|
||||
storeOp.getDeviceSourceOffset(),
|
||||
*deviceSourceOffset,
|
||||
storeOp.getSize());
|
||||
}
|
||||
|
||||
@@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp(
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
@@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) {
|
||||
|
||||
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
@@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
/// fully resolved before the JSON instructions are emitted.
|
||||
/// Returns the number of emitted instructions, or -1 on failure.
|
||||
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
|
||||
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
|
||||
if (!coreOp)
|
||||
return std::nullopt;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
};
|
||||
size_t processedOperations = 0;
|
||||
auto result =
|
||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
@@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
||||
auto weightIndex = resolveWeightIndex(vmmOp);
|
||||
if (!weightIndex)
|
||||
return failure();
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
@@ -1004,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
std::optional<MemoryReportRow> batchPerCoreRow;
|
||||
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
|
||||
SmallVector<size_t> orderedOriginalCoreIds;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
|
||||
if (inserted)
|
||||
orderedOriginalCoreIds.push_back(originalCoreId);
|
||||
it->second.push_back(lane);
|
||||
}
|
||||
|
||||
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
MemoryReportRow laneRow;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||
"pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||
|
||||
llvm::cl::opt<long> coresCount("core-count",
|
||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||
llvm::cl::init(-1));
|
||||
|
||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||
"dcp-critical-window-size",
|
||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||
llvm::cl::init(4000));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
|
||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||
|
||||
void verifyExplicitPimCoreCount() {
|
||||
if (!hasExplicitPimCoreCount())
|
||||
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||
if (coresCount.getValue() <= 0)
|
||||
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,8 +20,14 @@ typedef enum {
|
||||
EmitPimCodegen = 3
|
||||
} PimEmissionTargetType;
|
||||
|
||||
typedef enum {
|
||||
MergeSchedulerPeft = 0,
|
||||
MergeSchedulerDcp = 1,
|
||||
} PimMergeSchedulerType;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
// This option, by default set to false, will ignore an error when resolving a
|
||||
// specific tiles of the operands of a concat. This specific case is when the
|
||||
// wanted tile is generated by two separate operands of the concat. If this is
|
||||
|
||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
PassManager& pm,
|
||||
EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) {
|
||||
verifyExplicitPimCoreCount();
|
||||
|
||||
if (pimOnlyCodegen) {
|
||||
// Skip all the lowering passes and directly generate code for PIM.
|
||||
|
||||
@@ -33,7 +33,7 @@ struct DenseWeightView {
|
||||
};
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -70,7 +88,8 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
@@ -80,6 +99,28 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return view;
|
||||
@@ -87,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
@@ -18,13 +18,17 @@ namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -93,14 +99,15 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
auto bodyResult = detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
@@ -123,6 +130,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -131,13 +140,13 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
|
||||
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -136,7 +136,7 @@ tileMatrix(mlir::Value& matrixToTile,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||
return llvm::all_of(extractOp.getIndices(),
|
||||
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||
return DenseElementsAttr::get(resultType, values);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||
tensor::ExtractSliceOp extractSliceOp) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<Attribute> resultValues;
|
||||
resultValues.reserve(resultType.getNumElements());
|
||||
|
||||
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||
int64_t remaining = linearIndex;
|
||||
int64_t sourceLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||
}
|
||||
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultType, resultValues);
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
|
||||
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||
// lowering stages can still recognize grouped/sliced constants.
|
||||
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm;
|
||||
perm.reserve(transposeOp.getPermAttr().size());
|
||||
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||
perm.push_back(attr.getInt());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -11,7 +12,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,17 +5,15 @@
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
@@ -46,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||
if (!computes.empty() || !computeBatches.empty())
|
||||
return;
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
@@ -87,17 +86,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||
continue;
|
||||
|
||||
// Transpose stays globally legal because constant/view-only cases are
|
||||
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||
// spat.compute before the host legality check.
|
||||
auto resultType = transposeOp.getResult().getType();
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||
});
|
||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
ConversionTarget preTarget(*ctx);
|
||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
|
||||
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
|
||||
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -130,30 +180,17 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet conversionPatterns(ctx);
|
||||
populateConversionPatterns(conversionPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
||||
<< coresCount << ")";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
@@ -162,14 +199,29 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
ConversionTarget postTarget(*ctx);
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static Value lowerSingleConvGroup(Value x,
|
||||
Value w,
|
||||
Value b,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType wType,
|
||||
RankedTensorType outType,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -589,9 +488,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||
outType,
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
@@ -599,10 +497,240 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() < 1) {
|
||||
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
const int64_t group = convOp.getGroup();
|
||||
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
|
||||
if (numChannelsIn % group != 0) {
|
||||
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (numChannelsOut % group != 0) {
|
||||
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (wType.getDimSize(0) != numChannelsOut) {
|
||||
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||
<< numChannelsOut << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
if (group == 1) {
|
||||
rewriter.replaceOp(convOp,
|
||||
lowerSingleConvGroup(x,
|
||||
w,
|
||||
b,
|
||||
xType,
|
||||
wType,
|
||||
outType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||
SmallVector<Value> bSlices;
|
||||
if (hasB) {
|
||||
auto biasType = cast<RankedTensorType>(b.getType());
|
||||
int64_t biasAxis = -1;
|
||||
if (biasType.getRank() == 1)
|
||||
biasAxis = 0;
|
||||
else if (biasType.getRank() == 2)
|
||||
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||
else {
|
||||
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||
<< biasType.getRank();
|
||||
return failure();
|
||||
}
|
||||
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||
}
|
||||
|
||||
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value> groupResults;
|
||||
groupResults.reserve(group);
|
||||
auto groupOutType =
|
||||
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||
Value groupX = xSlices[groupId];
|
||||
Value groupW = wSlices[groupId];
|
||||
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||
groupW,
|
||||
groupB,
|
||||
cast<RankedTensorType>(groupX.getType()),
|
||||
cast<RankedTensorType>(groupW.getType()),
|
||||
groupOutType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||
});
|
||||
result = concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -402,13 +402,29 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
|
||||
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
|
||||
for (Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
for (Value input : aHSlices[coreId]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
Block* body =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(
|
||||
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
@@ -416,10 +432,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
return failure();
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp->getResult(0));
|
||||
}
|
||||
@@ -502,9 +515,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
@@ -533,37 +543,47 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
sharedBias = c;
|
||||
}
|
||||
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
||||
|
||||
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
TypeRange {outType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||
ValueRange(weights),
|
||||
ValueRange(aSlices));
|
||||
ValueRange {b},
|
||||
ValueRange {a});
|
||||
|
||||
Block* body = rewriter.createBlock(
|
||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
|
||||
SmallVector<Location> blockArgLocs(4, loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
Value lane = batchOp.getLaneArgument();
|
||||
Value weight = batchOp.getWeightArgument(0);
|
||||
Value packedInput = batchOp.getInputArgument(0);
|
||||
Value packedOutput = batchOp.getOutputArgument(0);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value row =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
.getResult();
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
|
||||
unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||
if (rhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||
return failure();
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value collapseBatchDims(Value value,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType =
|
||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||
};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value expandBatchDims(Value value,
|
||||
RankedTensorType outputType,
|
||||
size_t batchRank,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||
};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static Value buildLoopSoftmaxSlice(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
ArrayRef<Value> outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
int64_t rank = inputType.getRank();
|
||||
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
|
||||
sliceShape.push_back(inputType.getDimSize(rank - 1));
|
||||
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||
offsets.reserve(rank);
|
||||
sizes.reserve(rank);
|
||||
|
||||
for (Value outerIndex : outerIndices) {
|
||||
offsets.push_back(outerIndex);
|
||||
sizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
|
||||
|
||||
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxNest(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
int64_t axis,
|
||||
SmallVectorImpl<Value>& outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value loopIndex = loop.getInductionVar();
|
||||
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||
outerIndices.push_back(loopIndex);
|
||||
Value updatedAccumulator =
|
||||
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||
outerIndices.pop_back();
|
||||
|
||||
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
if (inputType.getRank() == 1) {
|
||||
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||
return;
|
||||
}
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||
SmallVector<Value> outerIndices;
|
||||
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (axis == inputType.getRank())
|
||||
return createSoftmaxCompute(input, rewriter, loc);
|
||||
|
||||
if (axis == softmaxAxis)
|
||||
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
|
||||
|
||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
||||
SmallVector<Value> rebuiltSlices;
|
||||
rebuiltSlices.reserve(slices.size());
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
|
||||
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
return replaceWithReshape([&](Value data) -> Value {
|
||||
Value reshaped = data;
|
||||
if (sourceType.getRank() != 1) {
|
||||
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||
reshaped = tensor::CollapseShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||
}
|
||||
if (resultType.getRank() == 1)
|
||||
return reshaped;
|
||||
return tensor::ExpandShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||
.getResult();
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -15,42 +15,88 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(inputType.getRank());
|
||||
for (int64_t dim : inputType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(1);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
static Value buildNearestAsymmetricIndex(
|
||||
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
}
|
||||
|
||||
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
|
||||
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
|
||||
}
|
||||
|
||||
static Value buildNearestResize(Value input,
|
||||
ArrayRef<int64_t> inputShape,
|
||||
ArrayRef<int64_t> outputShape,
|
||||
int64_t axis,
|
||||
static Value buildNearestResizeLoop(Value input,
|
||||
RankedTensorType inputType,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == static_cast<int64_t>(outputShape.size()))
|
||||
return input;
|
||||
auto elemType = resultType.getElementType();
|
||||
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
|
||||
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
|
||||
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(outputShape[axis]);
|
||||
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
|
||||
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
|
||||
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
|
||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||
}
|
||||
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, slices);
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
|
||||
Value outputN = batchLoop.getInductionVar();
|
||||
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
|
||||
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
|
||||
|
||||
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
|
||||
rewriter.setInsertionPointToStart(channelLoop.getBody());
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC =
|
||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH =
|
||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW =
|
||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||
Value updatedOutput =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(widthLoop);
|
||||
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(heightLoop);
|
||||
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(channelLoop);
|
||||
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
return batchLoop.getResult(0);
|
||||
}
|
||||
|
||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
|
||||
if (inputType.getRank() != 4 || resultType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
|
||||
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
|
||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result =
|
||||
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
||||
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||
});
|
||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||
|
||||
@@ -31,46 +31,18 @@ static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
return success();
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
@@ -81,11 +53,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
@@ -116,8 +86,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -126,14 +104,17 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
@@ -165,11 +146,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
@@ -205,8 +184,25 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
|
||||
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -215,31 +211,28 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
||||
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
@@ -247,10 +240,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
@@ -262,4 +251,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,9 +3,13 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
|
||||
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
|
||||
|
||||
auto entryFunc = getPimEntryFunc(module);
|
||||
if (failed(entryFunc)) {
|
||||
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -17,6 +18,37 @@ namespace {
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
FailureOr<int32_t> constantValue = getConstantI32Value(value);
|
||||
if (failed(constantValue))
|
||||
return failure();
|
||||
constants.push_back(*constantValue);
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
|
||||
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber());
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
@@ -28,27 +60,30 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
|
||||
pim::PimSendTensorBatchOp::create(rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
return success();
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||
@@ -56,24 +91,26 @@ static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatc
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
if (computeBatchOp.getNumResults() != 0)
|
||||
return computeBatchOp.emitOpError(
|
||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
|
||||
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
if (oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
|
||||
if (computeBatchOp.getNumResults() != 0)
|
||||
return computeBatchOp.emitOpError(
|
||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
|
||||
"materialize explicit communication before lowering to PIM");
|
||||
|
||||
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
if (!oldYield || oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
@@ -102,7 +139,12 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
|
||||
IRMapping mapper;
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
|
||||
mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
|
||||
BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
@@ -142,20 +184,31 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
continue;
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
loc,
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||
sendBatchOp.getTargetCoreIdsAttr());
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
|
||||
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
@@ -163,14 +216,15 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||
receiveBatchOp.getSourceCoreIdsAttr())
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveBatchOp.getOutput(), received);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
|
||||
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -178,6 +232,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
auto clonedTensor = cloned->getResult(0);
|
||||
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
|
||||
mapper.map(toTensorOp.getResult(), clonedTensor);
|
||||
continue;
|
||||
}
|
||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
@@ -194,9 +252,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
}
|
||||
}
|
||||
|
||||
for (Value operand : op.getOperands()) {
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||
continue;
|
||||
if (isExplicitHostOperand(&op, operandIndex))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||
|
||||
@@ -22,6 +22,8 @@ add_pim_library(OMSpatialToPim
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRSCFUtils
|
||||
MLIRTransformUtils
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
@@ -12,15 +13,24 @@ namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getInput(),
|
||||
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
||||
pim::PimSendOp::create(
|
||||
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -42,7 +52,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
op.getResult().getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
||||
op.getSourceCoreId())
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
@@ -53,11 +63,12 @@ struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTens
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : op.getTargetCoreIds())
|
||||
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = toPimCoreId(targetCoreId);
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -67,16 +78,17 @@ struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelRecei
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = toPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
|
||||
@@ -29,7 +29,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
unsigned inputIndex,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = body.getArgument(inputIndex);
|
||||
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
|
||||
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
|
||||
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
bodyArgument.replaceAllUsesWith(replacement);
|
||||
@@ -37,7 +40,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
compute.getInputsMutable().erase(inputIndex);
|
||||
else
|
||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||
body.eraseArgument(inputIndex);
|
||||
body.eraseArgument(bodyArgIndex);
|
||||
rewriter.finalizeOpModification(owner);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
@@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<tensor::EmptyOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
@@ -48,6 +55,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
@@ -92,7 +111,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
@@ -101,7 +122,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 0)
|
||||
if (block.getNumArguments() != computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
@@ -110,8 +131,10 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights()))
|
||||
mapping.map(computeOp.getWeightArgument(weightIndex), weight);
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -133,7 +156,7 @@ void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder))
|
||||
return success();
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -143,21 +166,42 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
auto& block = computeOp.getRegion().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
||||
if (!receiveOp || blockArg.use_empty())
|
||||
continue;
|
||||
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (receiveOp && !blockArg.use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
||||
Value received = PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
Value received =
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||
if (receiveTensorOp && !blockArg.use_empty()) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveTensorOp);
|
||||
}
|
||||
}
|
||||
|
||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||
@@ -197,11 +241,36 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||
if (!blockArg.use_empty())
|
||||
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
||||
block.eraseArguments(0, block.getNumArguments());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (blockArg.use_empty())
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto inputType = dyn_cast<ShapedType>(input.getType());
|
||||
if (!inputType)
|
||||
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
|
||||
auto copied =
|
||||
PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
||||
outputBuffer,
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(copied);
|
||||
}
|
||||
if (!computeOp.getInputs().empty())
|
||||
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
|
||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||
Block* tempComputeBlock = new Block();
|
||||
computeOp.getBody().push_back(tempComputeBlock);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -11,6 +12,7 @@ struct CoreLoweringState {
|
||||
size_t& nextCoreId;
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
mlir::OperationFolder& constantFolder;
|
||||
};
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||
|
||||
@@ -76,8 +76,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
continue;
|
||||
@@ -89,14 +88,13 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
continue;
|
||||
@@ -108,7 +106,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
@@ -254,7 +252,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
}
|
||||
}
|
||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
Value hostConstant = constantOp.getResult();
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
@@ -264,40 +262,22 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
else {
|
||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
||||
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -318,7 +320,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
return success();
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
@@ -327,7 +330,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<tensor::EmptyOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
@@ -337,15 +345,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||
static void cloneHelperChain(Value sourceValue,
|
||||
ArrayRef<Operation*> helperChain,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
Value& clonedValue) {
|
||||
IRMapping mapping;
|
||||
mapping.map(sourceValue, sourceValue);
|
||||
clonedValue = sourceValue;
|
||||
|
||||
rewriter.setInsertionPointAfterValue(sourceValue);
|
||||
for (Operation* op : helperChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -360,14 +371,19 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes) {
|
||||
int32_t sizeInBytes,
|
||||
OperationFolder& constantFolder) {
|
||||
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
||||
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
||||
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
|
||||
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
|
||||
return PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
hostTargetOffsetValue,
|
||||
deviceSourceOffsetValue,
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
@@ -411,69 +427,84 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
}
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
Location loc = producerOp->getLoc();
|
||||
OperationFolder constantFolder(producerOp->getContext());
|
||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||
|
||||
if (auto returnUse = analyzeReturnUse(result)) {
|
||||
Value storedValue = yieldValue;
|
||||
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
||||
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
||||
Value currentStoredValue = storedValue;
|
||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(state, op);
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
currentStoredValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto resultUses = result.getUses();
|
||||
auto resultUses = producedValue.getUses();
|
||||
if (rangeLength(resultUses) == 1) {
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
storedValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(state, concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
storedValue,
|
||||
static_cast<int32_t>(flatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
producerOp->emitOpError(
|
||||
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
@@ -484,7 +515,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
SmallVector<int64_t> destinationIndices;
|
||||
if (failed(mapIndicesThroughHelperChain(
|
||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
@@ -503,7 +534,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
auto scalarTensorType =
|
||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter.setInsertionPointAfter(elementSlice);
|
||||
|
||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||
@@ -513,7 +544,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize));
|
||||
static_cast<int32_t>(elementSize),
|
||||
constantFolder);
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
@@ -521,6 +553,11 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter);
|
||||
}
|
||||
|
||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||
if (!op)
|
||||
@@ -569,7 +606,16 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
markOpToRemove(state, concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||
markOpToRemove(state, receiveOp);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||
markOpToRemove(state, receiveTensorOp);
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
|
||||
@@ -32,6 +32,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute compu
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||
mlir::Value producedValue,
|
||||
mlir::Value storedValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat<
|
||||
>;
|
||||
|
||||
def spatToPimVMM : Pat<
|
||||
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimVMMOp $weightIndex, $vector,
|
||||
(SpatVMMOp:$srcOpRes $weight, $vector),
|
||||
(PimVMMOp $weight, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -12,6 +13,8 @@
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@@ -104,23 +107,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
OperationFolder& constantFolder) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||
|
||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||
return PimMemCopyHostToDevBatchOp::create(
|
||||
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
return PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
outputBuffer,
|
||||
zeroValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
sizeAttr)
|
||||
.getOutput();
|
||||
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
static Value
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
@@ -131,25 +145,27 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
coreId = 0;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
OperationFolder constantFolder(&getContext());
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect,
|
||||
@@ -169,34 +185,32 @@ void SpatialToPimPass::runOnOperation() {
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
RewritePatternSet initialPatterns(ctx);
|
||||
populateWithGenerated(initialPatterns);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
||||
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(patterns);
|
||||
RewritePatternSet globalTensorPatterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
|
||||
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||
}
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -205,17 +219,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
RewritePatternSet initialTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
@@ -229,7 +242,6 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
@@ -238,6 +250,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -247,30 +260,27 @@ void SpatialToPimPass::runOnOperation() {
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
||||
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
|
||||
{
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
communicationTarget.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
@@ -291,12 +301,13 @@ void SpatialToPimPass::runOnOperation() {
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
populateChannelLoweringPatterns(communicationPatterns);
|
||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
||||
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -306,6 +317,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
@@ -313,7 +325,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
auto paddedOutputType = RankedTensorType::get(
|
||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||
@@ -340,10 +352,13 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||
Type elementType = tensorType.getElementType();
|
||||
if (!elementType.isIntOrFloat())
|
||||
return;
|
||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||
|
||||
@@ -353,10 +368,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(
|
||||
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
|
||||
deviceTensor,
|
||||
inputTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||
|
||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||
|
||||
+34
-16
@@ -2,6 +2,7 @@
|
||||
#define PIM_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
@@ -24,7 +25,8 @@ def PimTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Execute a block on a PIM core";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||
I32Attr:$coreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
|
||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Execute equivalent batched core bodies";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
|
||||
Variadic<PimTensor>:$inputs
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getLaneArgument();
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
@@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> {
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
I32Attr:$size,
|
||||
I32Attr:$targetCoreId
|
||||
Index:$targetCoreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
`(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
I32Attr:$size,
|
||||
I32Attr:$sourceCoreId
|
||||
Index:$sourceCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
`(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$deviceTargetOffset,
|
||||
Index:$hostSourceOffset,
|
||||
PimTensor:$deviceTarget,
|
||||
PimTensor:$hostSource,
|
||||
I32Attr:$deviceTargetOffset,
|
||||
I32Attr:$hostSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||
`[` $deviceTargetOffset `,` $hostSourceOffset `]`
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict
|
||||
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from device memory into host memory";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$hostTargetOffset,
|
||||
Index:$deviceSourceOffset,
|
||||
PimTensor:$hostTarget,
|
||||
PimTensor:$deviceSource,
|
||||
I32Attr:$hostTargetOffset,
|
||||
I32Attr:$deviceSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
||||
`[` $hostTargetOffset `,` $deviceSourceOffset `]`
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict
|
||||
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Vector-matrix multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
PimTensor:$weight,
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
@@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
`[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
|
||||
type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,41 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
|
||||
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
|
||||
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
void PimDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
||||
@@ -20,6 +20,80 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
|
||||
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren:
|
||||
return parser.parseRParen();
|
||||
case ListDelimiter::Square:
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
|
||||
printer << " " << keyword << " ";
|
||||
printCompressedIntegerList(printer, coreIds);
|
||||
@@ -33,15 +107,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
|
||||
void PimCoreOp::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " coreId " << getCoreId();
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " -> () ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<Type> weightTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (hasCoreId && result.attributes.get("coreId"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
|
||||
if (hasCoreId)
|
||||
result.addAttribute("coreId", getI32Attr(parser, coreId));
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
return parser.parseRegion(*body, weightArgs);
|
||||
}
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getLaneArgument());
|
||||
printer << " = 0 to " << getLaneCount() << " ";
|
||||
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
|
||||
@@ -49,51 +184,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> () ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|
||||
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|
||||
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
if (parser.parseRegion(*body))
|
||||
return failure();
|
||||
|
||||
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|
||||
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|
||||
|| parser.parseLParen() || parser.parseRParen())
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
@@ -110,7 +251,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
laneArg.type = builder.getIndexType();
|
||||
regionArgs.push_back(laneArg);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void PimYieldOp::print(OpAsmPrinter& printer) {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -14,6 +16,52 @@ namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (isa<PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||
return blockArgument.getParentRegion();
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp ? definingOp->getParentRegion() : nullptr;
|
||||
}
|
||||
|
||||
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
bool hasFailure = false;
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|
||||
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic =
|
||||
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
||||
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
||||
}
|
||||
@@ -78,24 +126,46 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
return failure();
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimCoreOp::verify() {
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != getWeights().size())
|
||||
return emitError("core body must have one block argument per weight");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
||||
}
|
||||
|
||||
LogicalResult PimCoreBatchOp::verify() {
|
||||
if (getLaneCount() <= 0)
|
||||
return emitError("laneCount must be positive");
|
||||
Block& block = getBody().front();
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("core_batch body must have lane, weight, and input block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("core_batch first block argument must have index type");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("core_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
@@ -126,9 +196,9 @@ LogicalResult PimVMMOp::verify() {
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
return emitError("weight must be a shaped value");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
|
||||
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceTargetMemRef.getType(),
|
||||
memCopyHostToDevOp.getDeviceTargetOffset(),
|
||||
memCopyHostToDevOp.getHostSourceOffset(),
|
||||
deviceTargetMemRef,
|
||||
hostSourceMemRef,
|
||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||
memCopyDevToHostOp,
|
||||
hostTargetMemRef.getType(),
|
||||
memCopyDevToHostOp.getHostTargetOffset(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffset(),
|
||||
hostTargetMemRef,
|
||||
deviceSourceMemRef,
|
||||
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
||||
memCopyDevToHostOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
|
||||
op,
|
||||
outputBufferOpt->getType(),
|
||||
*outputBufferOpt,
|
||||
receiveOp.getSizeAttr(),
|
||||
receiveOp.getSourceCoreIdAttr());
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreIdAttr());
|
||||
sendOp.getTargetCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
||||
return {};
|
||||
}
|
||||
|
||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned weightIndex = bbArg.getArgNumber();
|
||||
return {
|
||||
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||
|
||||
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
SmallVector<Value>& invocationStack) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
@@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
||||
unsigned argNumber = bbArg.getArgNumber();
|
||||
if (argNumber == 0)
|
||||
return {};
|
||||
|
||||
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||
unsigned operandIndex = argNumber - 1;
|
||||
if (argNumber > weightCount + 1)
|
||||
operandIndex = weightCount + (argNumber - 1 - weightCount);
|
||||
|
||||
return {
|
||||
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
|
||||
{&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
|
||||
unsigned argNumber = bbArg.getArgNumber();
|
||||
if (argNumber == 0)
|
||||
return failure();
|
||||
|
||||
Value tiedOperand;
|
||||
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||
if (argNumber <= weightCount)
|
||||
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
|
||||
else
|
||||
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
|
||||
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
|
||||
return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
|
||||
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
BufferizationState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
|
||||
if (failed(weightOpt))
|
||||
return failure();
|
||||
|
||||
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
@@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
|
||||
return WalkResult::skip();
|
||||
});
|
||||
if (hasFailed) {
|
||||
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
@@ -48,6 +49,13 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
||||
for (Value result : user->getResults())
|
||||
pendingValues.push_back(result);
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||
if (initArg == value)
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
}
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
for (OpResult result : user->getResults()) {
|
||||
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
|
||||
|
||||
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
||||
static FailureOr<int64_t> getConstantI64(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
||||
static FailureOr<int32_t> getConstantI32(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
|
||||
return getConstantI64(sendOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI64(receiveOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getSourceCoreId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getTargetCoreId());
|
||||
}
|
||||
|
||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
||||
if (!endpoints.send || !endpoints.receive)
|
||||
return failure();
|
||||
|
||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
||||
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
|
||||
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendSourceCoreId != *receiveSourceCoreId) {
|
||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
||||
|
||||
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
|
||||
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendTargetCoreId != *receiveTargetCoreId) {
|
||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
|
||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
||||
|
||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].send = sendOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].send = sendOp;
|
||||
}
|
||||
|
||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].receive = receiveOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].receive = receiveOp;
|
||||
}
|
||||
|
||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.send = {};
|
||||
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
}
|
||||
|
||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.receive = {};
|
||||
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
||||
return failure();
|
||||
return endpointsOr->receive;
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->send)
|
||||
return failure();
|
||||
return endpointsOr->send;
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#define SPATIAL_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/RegionKindInterface.td"
|
||||
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def SpatialDialect : Dialect {
|
||||
let name = "spat";
|
||||
@@ -22,7 +26,9 @@ def SpatTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
def SpatCompute : SpatOp<"compute",
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -36,14 +42,20 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
[SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compressed batch of independent equivalent compute lanes";
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$laneCount,
|
||||
@@ -57,10 +69,41 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getLaneArgument();
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
::mlir::BlockArgument getOutputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||
Pure,
|
||||
Terminator,
|
||||
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||
HasParent<"SpatComputeBatch">,
|
||||
] # GraphRegionNoTerminator.traits> {
|
||||
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
||||
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<(ins)>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
|
||||
::mlir::OpResult getParentResult(int64_t idx);
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||
let summary = "Yield results from a compute region";
|
||||
|
||||
@@ -110,14 +153,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||
let summary = "Send a tensor through a logical channel";
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$channelId,
|
||||
I32Attr:$sourceCoreId,
|
||||
I32Attr:$targetCoreId,
|
||||
Index:$channelId,
|
||||
Index:$sourceCoreId,
|
||||
Index:$targetCoreId,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -125,9 +168,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
let summary = "Receive a tensor from a logical channel";
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$channelId,
|
||||
I32Attr:$sourceCoreId,
|
||||
I32Attr:$targetCoreId
|
||||
Index:$channelId,
|
||||
Index:$sourceCoreId,
|
||||
Index:$targetCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -135,31 +178,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -167,44 +212,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -212,16 +263,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -229,7 +282,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -240,7 +295,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
@@ -251,7 +306,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -259,7 +314,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
@@ -270,7 +325,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,74 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
|
||||
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(getWeights().size() + idx);
|
||||
}
|
||||
|
||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
if (index == 0) {
|
||||
setNameFn(getOutputArgument(index), "out");
|
||||
continue;
|
||||
}
|
||||
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Region* bodyRegion = result.addRegion();
|
||||
builder.createBlock(bodyRegion);
|
||||
}
|
||||
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
|
||||
return getRegion().front().getOperations();
|
||||
}
|
||||
|
||||
void SpatialDialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/RegionKindInterface.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
@@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
printer << " channels ";
|
||||
printCompressedIntegerList(printer, channelIds);
|
||||
printer << " from ";
|
||||
printCompressedIntegerList(printer, sourceCoreIds);
|
||||
printer << " to ";
|
||||
printCompressedIntegerList(printer, targetCoreIds);
|
||||
}
|
||||
|
||||
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
|
||||
return parser.getBuilder().getDenseI64ArrayAttr(values);
|
||||
}
|
||||
|
||||
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
@@ -47,94 +31,89 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
template <typename TensorSendOpTy>
|
||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
|
||||
ArrayRef<Type> weightTypes,
|
||||
ArrayRef<Type> outputTypes,
|
||||
OpAsmParser::Argument& laneArg,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
|
||||
Builder& builder) {
|
||||
laneArg.type = builder.getIndexType();
|
||||
regionArgs.push_back(laneArg);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
applyArgumentTypes(outputTypes, outputArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
llvm::append_range(regionArgs, outputArgs);
|
||||
}
|
||||
|
||||
static void
|
||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren:
|
||||
return parser.parseRParen();
|
||||
case ListDelimiter::Square:
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -243,9 +222,17 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
@@ -264,6 +251,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
}
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
@@ -272,10 +260,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
@@ -292,9 +281,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
@@ -313,19 +304,39 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOperand(getLaneArgument());
|
||||
printer << " = 0 to " << getLaneCount();
|
||||
|
||||
printer << " ";
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index)
|
||||
outputArgs.push_back(getOutputArgument(index));
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
@@ -337,9 +348,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
@@ -350,7 +358,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
}
|
||||
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
@@ -359,13 +372,20 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
@@ -381,10 +401,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
@@ -403,119 +428,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
auto& builder = parser.getBuilder();
|
||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||
if (parser.parseRegion(*region, regionArgs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
if (region->empty())
|
||||
OpBuilder(builder.getContext()).createBlock(region.get());
|
||||
result.addRegion(std::move(region));
|
||||
return parser.parseOptionalAttrDict(result.attributes);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -81,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
||||
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
|
||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
return failure();
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
@@ -104,15 +99,86 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return false;
|
||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
|
||||
return false;
|
||||
|
||||
unsigned argNumber = blockArg.getArgNumber();
|
||||
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
|
||||
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
|
||||
}
|
||||
|
||||
static bool isConstantIndexLike(Value value) {
|
||||
APInt constantValue;
|
||||
return matchPattern(value, m_ConstantInt(&constantValue));
|
||||
}
|
||||
|
||||
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||
if (value == laneArg || isConstantIndexLike(value))
|
||||
return true;
|
||||
|
||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||
if (!addOp)
|
||||
return false;
|
||||
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|
||||
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||
|
||||
for (int64_t size : sliceOp.getStaticSizes())
|
||||
if (ShapedType::isDynamic(size))
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
|
||||
BlockArgument laneArg,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
RankedTensorType sourceType = sliceOp.getSourceType();
|
||||
RankedTensorType destType = sliceOp.getDestType();
|
||||
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
|
||||
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||
|
||||
for (int64_t size : sliceOp.getStaticSizes())
|
||||
if (ShapedType::isDynamic(size))
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.empty())
|
||||
if (channelCount == 0)
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -124,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult
|
||||
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != static_cast<size_t>(*laneCount))
|
||||
if (channelCount != static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match parent laneCount");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult verifyTensorBatchChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -168,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
@@ -176,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return op->emitError("body must terminate with spat.yield");
|
||||
if (outputTypes.empty()) {
|
||||
if (yieldOp.getNumOperands() != 0)
|
||||
return op->emitError("body yield must be empty when compute_batch has no results");
|
||||
}
|
||||
else {
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return op->emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputTypes[0])
|
||||
return op->emitError("body yield type must match output type");
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
bool hasFailure = false;
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
|
||||
<< " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
if (batchOp.getNumResults() == 0) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
|
||||
if (yieldOp.getNumOperands() != 0)
|
||||
return batchOp.emitError("resultless compute_batch body yield must be empty");
|
||||
}
|
||||
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
|
||||
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
|
||||
}
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
|
||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
||||
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
|
||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
||||
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -205,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
||||
} // namespace
|
||||
|
||||
LogicalResult SpatMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -220,9 +311,9 @@ LogicalResult SpatMVMOp::verify() {
|
||||
}
|
||||
|
||||
LogicalResult SpatVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -338,8 +429,34 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
@@ -372,54 +489,59 @@ LogicalResult SpatCompute::verify() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (getInputArgument(inputIndex).use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
@@ -429,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("laneCount must be positive");
|
||||
|
||||
auto laneCountSz = static_cast<size_t>(count);
|
||||
if (getWeights().size() % laneCountSz != 0)
|
||||
return emitError("number of weights must be a multiple of laneCount");
|
||||
|
||||
if (!getInputs().empty() && getInputs().size() != laneCountSz)
|
||||
return emitError("number of inputs must be either 0 or laneCount");
|
||||
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
|
||||
return emitError("number of outputs must be either 0 or laneCount");
|
||||
|
||||
size_t weightsPerLane = getWeights().size() / laneCountSz;
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
|
||||
Type weightType = getWeights()[weightIndex].getType();
|
||||
for (size_t lane = 1; lane < laneCountSz; ++lane)
|
||||
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
|
||||
return emitError("corresponding weights across lanes must have the same type");
|
||||
}
|
||||
|
||||
if (!getInputs().empty()) {
|
||||
Type inputType = getInputs()[0].getType();
|
||||
for (Value in : getInputs().drop_front())
|
||||
if (in.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
}
|
||||
|
||||
if (!getOutputs().empty()) {
|
||||
Type outputType = getOutputs()[0].getType();
|
||||
for (Value out : getOutputs().drop_front())
|
||||
if (out.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
}
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||
@@ -465,27 +558,66 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
||||
return emitError("compute_batch coreIds values must be positive");
|
||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
DenseSet<int32_t> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
return emitError("compute_batch coreIds values must be distinct");
|
||||
return emitError("compute_batch coreIds values must be unique");
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (getInputs().empty()) {
|
||||
if (block.getNumArguments() != 0)
|
||||
return emitError("compute_batch body must have no block arguments when there are no inputs");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
else {
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("compute_batch body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != getInputs()[0].getType())
|
||||
return emitError("body block argument type must match input type");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||
if (blockArg.getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
BlockArgument blockArg = getOutputArgument(resultIndex);
|
||||
if (blockArg.getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||
return failure();
|
||||
return verifyBatchBody(*this, block);
|
||||
}
|
||||
|
||||
LogicalResult SpatInParallelOp::verify() {
|
||||
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
||||
if (!batchOp)
|
||||
return emitOpError("expected spat.compute_batch parent");
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
for (Operation& op : getRegion().front().getOperations()) {
|
||||
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSliceOp)
|
||||
return emitOpError("expected only tensor.parallel_insert_slice ops");
|
||||
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
|
||||
return failure();
|
||||
|
||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||
for (OpOperand& destination : destinations)
|
||||
if (!isBatchOutputArgument(batchOp, destination.get()))
|
||||
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -1,802 +1,19 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "../Scheduling/ComputeGraph.hpp"
|
||||
#include "../Scheduling/DcpScheduler.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
||||
|
||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
SmallVector<size_t, 4> originalComputeIndices;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
|
||||
struct VirtualGraph {
|
||||
std::vector<VirtualNode> nodes;
|
||||
std::vector<IndexedEdge> edges;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
std::vector<Time> aest;
|
||||
std::vector<Time> alst;
|
||||
std::vector<size_t> topologicalOrder;
|
||||
bool valid = false;
|
||||
};
|
||||
|
||||
struct WindowScheduleResult {
|
||||
std::vector<std::vector<size_t>> mergeGroups;
|
||||
CPU cpuCount = 0;
|
||||
size_t mergedNodeCount = 0;
|
||||
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget() {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
|
||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||
|
||||
size_t chunkIndex = 0;
|
||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||
else
|
||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||
return getBatchChunkForIndex(batch, chunkIndex);
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (auto [start, end, weight] : edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
if (startIndex == endIndex)
|
||||
continue;
|
||||
auto key = std::make_pair(startIndex, endIndex);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (auto [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back(
|
||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||
});
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
Weight getComputeBodyWeight(Region& body) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto& block : body)
|
||||
for ([[maybe_unused]] auto& op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& block : body)
|
||||
for (auto& op : block)
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeWeight(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeCrossbarUsage(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
SmallVector<Value, 4> inputs;
|
||||
inputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
inputs.push_back(batch.getInputs()[lane]);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
value = extract.getSource();
|
||||
op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
||||
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
||||
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
||||
SmallVector<ComputeInstance> instances;
|
||||
auto isUsedAsWeightOnly = [](Operation* producerOp) {
|
||||
if (producerOp->getNumResults() == 0)
|
||||
return false;
|
||||
for (Value result : producerOp->getResults()) {
|
||||
if (result.use_empty())
|
||||
return false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||
if (!llvm::is_contained(compute.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||
if (!llvm::is_contained(batch.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
for (Region& region : entryOp->getRegions()) {
|
||||
for (Block& block : region) {
|
||||
for (Operation& op : block) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||
continue;
|
||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||
continue;
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return instances;
|
||||
}
|
||||
|
||||
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
|
||||
VirtualGraph graph;
|
||||
graph.nodes.reserve(computeInstances.size());
|
||||
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
VirtualNode node;
|
||||
node.originalComputeIndices.push_back(index);
|
||||
node.weight = getComputeInstanceWeight(computeInstance);
|
||||
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
|
||||
graph.nodes.push_back(std::move(node));
|
||||
}
|
||||
graph.edges = aggregateEdges(edges);
|
||||
return graph;
|
||||
}
|
||||
|
||||
TimingInfo computeTiming(const VirtualGraph& graph) {
|
||||
TimingInfo timing;
|
||||
size_t nodeCount = graph.nodes.size();
|
||||
timing.aest.assign(nodeCount, 0);
|
||||
timing.alst.assign(nodeCount, 0);
|
||||
timing.topologicalOrder.reserve(nodeCount);
|
||||
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||
children[startIndex].push_back({endIndex, edgeWeight});
|
||||
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||
incomingEdgeCount[endIndex]++;
|
||||
}
|
||||
|
||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||
if (!node.originalComputeIndices.empty())
|
||||
return node.originalComputeIndices.front();
|
||||
return nodeIndex;
|
||||
};
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
||||
if (lhsKey != rhsKey)
|
||||
return lhsKey > rhsKey;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t i = 0; i < nodeCount; ++i)
|
||||
if (incomingEdgeCount[i] == 0)
|
||||
readyNodes.push(i);
|
||||
|
||||
while (!readyNodes.empty()) {
|
||||
size_t current = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
timing.topologicalOrder.push_back(current);
|
||||
for (auto [child, weight] : children[current]) {
|
||||
(void) weight;
|
||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||
incomingEdgeCount[child]--;
|
||||
if (incomingEdgeCount[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
if (timing.topologicalOrder.size() != nodeCount)
|
||||
return timing;
|
||||
|
||||
Time dcpl = 0;
|
||||
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||
Time maxParentAest = 0;
|
||||
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||
maxParentAest =
|
||||
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||
}
|
||||
timing.aest[nodeIndex] = maxParentAest;
|
||||
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||
}
|
||||
|
||||
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||
Time minAlst = std::numeric_limits<Time>::max();
|
||||
if (children[nodeIndex].empty())
|
||||
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||
minAlst =
|
||||
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||
}
|
||||
timing.alst[nodeIndex] = minAlst;
|
||||
}
|
||||
|
||||
timing.valid = true;
|
||||
return timing;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto& neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||
if (lhsSlack != rhsSlack)
|
||||
return lhsSlack < rhsSlack;
|
||||
if (timing.aest[lhs] != timing.aest[rhs])
|
||||
return timing.aest[lhs] < timing.aest[rhs];
|
||||
return lhs < rhs;
|
||||
};
|
||||
|
||||
windowSize = std::min(windowSize, ranked.size());
|
||||
if (windowSize == 0)
|
||||
return {};
|
||||
if (windowSize == ranked.size()) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
return ranked;
|
||||
}
|
||||
|
||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||
if (criticalPoolSize < ranked.size())
|
||||
std::nth_element(
|
||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||
|
||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||
inCriticalPool[ranked[i]] = true;
|
||||
|
||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||
std::vector<size_t> selected;
|
||||
std::vector<char> inWindow(ranked.size(), false);
|
||||
selected.reserve(windowSize);
|
||||
|
||||
struct FrontierEntry {
|
||||
size_t node;
|
||||
};
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour] && eligible[neighbour])
|
||||
frontier.push({neighbour});
|
||||
};
|
||||
|
||||
addToWindow(seed, inCriticalPool);
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, inCriticalPool);
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
std::vector<char> anyNode(ranked.size(), true);
|
||||
for (size_t node : selected)
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour])
|
||||
frontier.push({neighbour});
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, anyNode);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
for (size_t node : ranked) {
|
||||
if (selected.size() == windowSize)
|
||||
break;
|
||||
if (!inWindow[node]) {
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::sort(selected, isHigherPriority);
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||
if (mappedStart == -1 || mappedEnd == -1)
|
||||
continue;
|
||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||
}
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
||||
std::vector<Weight> windowWeights;
|
||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||
std::vector<int64_t> windowNodeOrderKeys;
|
||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||
windowWeights.reserve(selectedNodes.size());
|
||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
||||
|
||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
||||
}
|
||||
|
||||
GraphDCP windowGraph(
|
||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
||||
if (coresCount.getValue() > 0)
|
||||
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||
windowGraph.setContext(context);
|
||||
windowGraph.runDcp();
|
||||
|
||||
WindowScheduleResult result;
|
||||
result.cpuCount = windowGraph.cpuCount();
|
||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.size() < 2)
|
||||
continue;
|
||||
|
||||
result.mergedNodeCount += scheduledTasks.size();
|
||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||
std::vector<size_t> mergeGroup;
|
||||
mergeGroup.reserve(scheduledTasks.size());
|
||||
for (const auto& task : scheduledTasks)
|
||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph& graph,
|
||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph& coarsenedGraph,
|
||||
std::vector<size_t>& oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
if (timing.valid)
|
||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||
topologicalRank[nodeIndex] = rank;
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto& mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||
return lhs < rhs;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
continue;
|
||||
for (size_t nodeIndex : mergeGroup) {
|
||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||
std::vector<size_t> newNodeRank;
|
||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||
bool mergedAny = false;
|
||||
coarsenedGraph.nodes.clear();
|
||||
coarsenedGraph.edges.clear();
|
||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||
newNodeRank.reserve(graph.nodes.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||
if (mergeGroupIndex == -1) {
|
||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
||||
memberNode.originalComputeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
|
||||
|
||||
mergedAny = true;
|
||||
newNodeIndex = coarsenedGraph.nodes.size();
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||
}
|
||||
|
||||
if (!mergedAny)
|
||||
return false;
|
||||
|
||||
std::vector<IndexedEdge> remappedEdges;
|
||||
remappedEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||
if (newStart == newEnd)
|
||||
continue;
|
||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||
continue;
|
||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||
}
|
||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
||||
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
|
||||
DCPAnalysisResult result;
|
||||
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> virtualNodeOrder;
|
||||
if (timing.valid) {
|
||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||
}
|
||||
else {
|
||||
virtualNodeOrder.resize(graph.nodes.size());
|
||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||
}
|
||||
|
||||
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
|
||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
||||
originalComputeToCpu[originalIndex] = cpu;
|
||||
}
|
||||
|
||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
size_t cpu = originalComputeToCpu[originalIndex];
|
||||
result.dominanceOrderCompute.push_back(computeInstance);
|
||||
result.computeToCpuMap[computeInstance] = cpu;
|
||||
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
|
||||
result.computeToAestMap[computeInstance] = originalIndex;
|
||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
||||
}
|
||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
result.isLastComputeOfCpu.insert(lastCompute);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
|
||||
DCPAnalysisResult result;
|
||||
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
|
||||
|
||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.empty())
|
||||
continue;
|
||||
|
||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||
ComputeInstance instance = computeInstances[task.nodeIndex];
|
||||
result.computeToCpuMap[instance] = cpu;
|
||||
result.computeToCpuSlotMap[instance] = slot;
|
||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||
}
|
||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
DCPAnalysisResult
|
||||
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
||||
SmallVector<Weight> nodeWeights;
|
||||
SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||
SmallVector<int64_t> nodeOrderKeys;
|
||||
nodeWeights.reserve(computeInstances.size());
|
||||
nodeCrossbarUsage.reserve(computeInstances.size());
|
||||
nodeOrderKeys.reserve(computeInstances.size());
|
||||
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
|
||||
nodeWeights.push_back(getComputeInstanceWeight(instance));
|
||||
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
|
||||
nodeOrderKeys.push_back(static_cast<int64_t>(index));
|
||||
}
|
||||
|
||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||
if (coresCount.getValue() > 0)
|
||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||
graphDCP.setContext(context);
|
||||
graphDCP.runDcp();
|
||||
return buildResultFromScheduledGraph(graphDCP, computeInstances);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||
if (!op)
|
||||
return {};
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
op = extract.getSource().getDefiningOp();
|
||||
if (!op)
|
||||
return {};
|
||||
}
|
||||
if (auto res = dyn_cast<SpatCompute>(op))
|
||||
return res;
|
||||
return {};
|
||||
}
|
||||
|
||||
DCPAnalysisResult DCPAnalysis::run() {
|
||||
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
|
||||
SmallVector<IndexedEdge, 10> edges;
|
||||
|
||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
instanceToIndex.reserve(computeInstances.size());
|
||||
for (auto [index, instance] : llvm::enumerate(computeInstances))
|
||||
instanceToIndex[instance] = index;
|
||||
|
||||
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
for (Value input : getComputeInstanceInputs(computeInstance)) {
|
||||
if (auto producerInstance = getOriginalComputeInstance(input)) {
|
||||
auto producerIt = instanceToIndex.find(*producerInstance);
|
||||
assert(producerIt != instanceToIndex.end());
|
||||
auto indexStartEdge = producerIt->second;
|
||||
edges.push_back({static_cast<int64_t>(indexStartEdge),
|
||||
static_cast<int64_t>(indexEndEdge),
|
||||
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (coresCount.getValue() > 0) {
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
||||
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
if (needsExactScheduledBatches)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
}
|
||||
|
||||
if (dcpCriticalWindowSize.getValue() == 0)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount);
|
||||
return false;
|
||||
}
|
||||
|
||||
VirtualGraph coarsenedGraph;
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount,
|
||||
windowSchedule.mergeGroups.size(),
|
||||
windowSchedule.mergedNodeCount,
|
||||
windowSchedule.maxMergeGroupSize,
|
||||
coarsenedGraph.nodes.size(),
|
||||
oldNodeCount - coarsenedGraph.nodes.size());
|
||||
virtualGraph = std::move(coarsenedGraph);
|
||||
return true;
|
||||
};
|
||||
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
SmallVector<size_t> selectedNodes;
|
||||
auto criticalWindow =
|
||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
selectedNodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
|
||||
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||
DcpScheduleOptions options;
|
||||
if (coresCount.getValue() > 0)
|
||||
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
|
||||
options.allowFallbackForAutoCoreCount = true;
|
||||
return runDcpScheduler(graph, options, entryOp->getContext());
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -2,37 +2,12 @@
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
// A scheduling identity that covers both spat.compute and scheduled shards of
|
||||
// spat.compute_batch.
|
||||
struct ComputeInstance {
|
||||
mlir::Operation* op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance& other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
};
|
||||
#include "../Scheduling/MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
using DCPAnalysisResult = MergeScheduleResult;
|
||||
|
||||
struct DCPAnalysis {
|
||||
private:
|
||||
DCPAnalysisResult result;
|
||||
@@ -50,16 +25,4 @@ public:
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<ComputeInstance> {
|
||||
static ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
||||
};
|
||||
} // namespace llvm
|
||||
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "Scheduling/MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
class MergeScheduleMaterializer {
|
||||
public:
|
||||
mlir::LogicalResult
|
||||
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,743 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "PostMergeCompaction.hpp"
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
|
||||
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
||||
|
||||
class ScopedMergePhaseTimer {
|
||||
public:
|
||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||
if (enabled)
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
~ScopedMergePhaseTimer() {
|
||||
if (!enabled)
|
||||
return;
|
||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||
}
|
||||
|
||||
private:
|
||||
bool enabled = false;
|
||||
std::string phase;
|
||||
std::chrono::steady_clock::time_point start;
|
||||
};
|
||||
|
||||
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||
|
||||
static FailureOr<int64_t> getConstantI64Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int32_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
struct RebatchKey {
|
||||
unsigned inputCount = 0;
|
||||
unsigned resultCount = 0;
|
||||
unsigned weightCount = 0;
|
||||
uint64_t phase = 0;
|
||||
bool hasPhase = false;
|
||||
uint64_t structureHash = 0;
|
||||
|
||||
bool operator==(const RebatchKey& other) const {
|
||||
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
|
||||
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
|
||||
}
|
||||
};
|
||||
|
||||
struct RebatchKeyInfo {
|
||||
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
||||
|
||||
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
||||
|
||||
static unsigned getHashValue(const RebatchKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
||||
}
|
||||
|
||||
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
|
||||
|
||||
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
|
||||
|
||||
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
|
||||
|
||||
RebatchKey computeRebatchKey(SpatCompute compute) {
|
||||
llvm::hash_code structureHash =
|
||||
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
|
||||
|
||||
for (Value weight : compute.getWeights())
|
||||
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
|
||||
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
|
||||
structureHash = llvm::hash_combine(structureHash, *phase);
|
||||
|
||||
Block& body = compute.getBody().front();
|
||||
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
|
||||
for (BlockArgument arg : body.getArguments())
|
||||
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
|
||||
|
||||
for (Operation& op : body) {
|
||||
structureHash = llvm::hash_combine(
|
||||
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
|
||||
for (Type type : op.getResultTypes())
|
||||
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
|
||||
for (NamedAttribute attr : op.getAttrs())
|
||||
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
|
||||
}
|
||||
|
||||
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
|
||||
return {static_cast<unsigned>(compute.getInputs().size()),
|
||||
static_cast<unsigned>(compute.getResultTypes().size()),
|
||||
static_cast<unsigned>(compute.getWeights().size()),
|
||||
phase.value_or(0),
|
||||
phase.has_value(),
|
||||
static_cast<uint64_t>(structureHash)};
|
||||
}
|
||||
|
||||
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
if (!lhs || !rhs)
|
||||
return false;
|
||||
if (lhs.getInputs().size() != rhs.getInputs().size())
|
||||
return false;
|
||||
if (lhs.getResultTypes() != rhs.getResultTypes())
|
||||
return false;
|
||||
if (lhs.getWeights().size() != rhs.getWeights().size())
|
||||
return false;
|
||||
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
|
||||
return false;
|
||||
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
||||
return false;
|
||||
|
||||
auto& lhsBlock = lhs.getBody().front();
|
||||
auto& rhsBlock = rhs.getBody().front();
|
||||
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
|
||||
return false;
|
||||
|
||||
DenseMap<Value, Value> mappedValues;
|
||||
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
|
||||
if (lhsArg.getType() != rhsArg.getType())
|
||||
return false;
|
||||
mappedValues[lhsArg] = rhsArg;
|
||||
}
|
||||
auto lhsIt = lhsBlock.begin();
|
||||
auto rhsIt = rhsBlock.begin();
|
||||
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
|
||||
Operation& lhsOp = *lhsIt;
|
||||
Operation& rhsOp = *rhsIt;
|
||||
|
||||
if (lhsOp.getName() != rhsOp.getName())
|
||||
return false;
|
||||
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
|
||||
return false;
|
||||
if (lhsOp.getNumResults() != rhsOp.getNumResults())
|
||||
return false;
|
||||
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
|
||||
return false;
|
||||
|
||||
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
|
||||
auto mapped = mappedValues.find(lhsOperand);
|
||||
if (mapped != mappedValues.end()) {
|
||||
if (mapped->second != rhsOperand)
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (lhsOperand != rhsOperand)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
|
||||
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
|
||||
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
|
||||
return false;
|
||||
}
|
||||
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
|
||||
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
|
||||
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
|
||||
return false;
|
||||
}
|
||||
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
|
||||
return false;
|
||||
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
|
||||
mappedValues[lhsResult] = rhsResult;
|
||||
}
|
||||
|
||||
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||
}
|
||||
|
||||
struct BatchYieldInfo {
|
||||
Value yieldedValue;
|
||||
tensor::ParallelInsertSliceOp insertSlice;
|
||||
};
|
||||
|
||||
static bool isHostOnlyBatchResultUser(Operation* user) {
|
||||
return isa<func::ReturnOp,
|
||||
spatial::SpatConcatOp,
|
||||
tensor::ExtractSliceOp,
|
||||
tensor::CastOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp>(user);
|
||||
}
|
||||
|
||||
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
|
||||
Block& block = batchOp.getBody().front();
|
||||
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
|
||||
if (!inParallel)
|
||||
return failure();
|
||||
|
||||
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
|
||||
for (Operation& op : inParallel.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSlice)
|
||||
return failure();
|
||||
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||
if (!outputArg || outputArg.getOwner() != &block)
|
||||
return failure();
|
||||
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
|
||||
}
|
||||
return batchYieldByOutputArg;
|
||||
}
|
||||
|
||||
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
|
||||
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return failure();
|
||||
|
||||
Block& oldBlock = batchOp.getBody().front();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
|
||||
batchOp.getWeights(),
|
||||
batchOp.getInputs());
|
||||
newBatch.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
|
||||
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
|
||||
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
|
||||
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
|
||||
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
|
||||
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
|
||||
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
|
||||
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
|
||||
}
|
||||
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
|
||||
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
return newBatch;
|
||||
}
|
||||
|
||||
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
|
||||
|
||||
for (auto batchOp : batches) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
continue;
|
||||
|
||||
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
|
||||
|
||||
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
|
||||
if (failed(batchYieldInfo))
|
||||
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
|
||||
|
||||
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
|
||||
if (failed(newBatch))
|
||||
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
|
||||
|
||||
Block& oldBlock = batchOp.getBody().front();
|
||||
Block& newBlock = newBatch->getBody().front();
|
||||
IRMapping mapper;
|
||||
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
|
||||
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
|
||||
auto oldIt = oldBlock.begin();
|
||||
auto newIt = newBlock.begin();
|
||||
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
|
||||
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
|
||||
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
|
||||
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
|
||||
auto yieldInfoIt = batchYieldInfo->find(outputArg);
|
||||
if (yieldInfoIt == batchYieldInfo->end())
|
||||
return batchOp.emitOpError(
|
||||
"missing yielded value for compute_batch result during communication materialization");
|
||||
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
|
||||
|
||||
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
|
||||
SmallVector<OpOperand*> hostUses;
|
||||
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
|
||||
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
|
||||
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
|
||||
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
|
||||
continue;
|
||||
}
|
||||
if (isHostOnlyBatchResultUser(use.getOwner())) {
|
||||
hostUses.push_back(&use);
|
||||
continue;
|
||||
}
|
||||
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
|
||||
<< ": " << use.getOwner()->getName();
|
||||
}
|
||||
|
||||
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
|
||||
if (uses.empty())
|
||||
return success();
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
channelIds.reserve(sourceCoreIds.size());
|
||||
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
|
||||
channelIds.push_back(nextChannelId++);
|
||||
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
|
||||
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
sendChannelIdValues,
|
||||
sendSourceCoreIdValues,
|
||||
sendTargetCoreIdValues,
|
||||
mappedYieldedValue);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointAfter(newBatch->getOperation());
|
||||
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
|
||||
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
|
||||
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
batchOp.getResult(resultIndex).getType(),
|
||||
receiveChannelIdValues,
|
||||
receiveSourceCoreIdValues,
|
||||
receiveTargetCoreIdValues);
|
||||
for (OpOperand* use : uses)
|
||||
use->set(received.getOutput());
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
return success();
|
||||
};
|
||||
|
||||
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
|
||||
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
|
||||
if (failed(createReceiveForUses(uses, targetCoreIds)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!hostUses.empty()) {
|
||||
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
|
||||
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
|
||||
rewriter.eraseOp(batchOp);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
DenseSet<Operation*> consumed;
|
||||
DenseMap<Operation*, size_t> computeOrder;
|
||||
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
|
||||
|
||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||
computeOrder[compute.getOperation()] = index;
|
||||
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
|
||||
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
|
||||
}
|
||||
|
||||
for (size_t index = 0; index < computes.size(); ++index) {
|
||||
auto anchor = computes[index];
|
||||
if (consumed.contains(anchor))
|
||||
continue;
|
||||
if (anchor.getInputs().size() > 1)
|
||||
continue;
|
||||
if (!anchor.getResults().empty())
|
||||
continue;
|
||||
|
||||
SmallVector<SpatCompute> group {anchor};
|
||||
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
|
||||
if (auto coreId = getComputeCoreId(anchor))
|
||||
usedCoreIds.insert(*coreId);
|
||||
|
||||
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
|
||||
if (bucketIt == candidatesByKey.end())
|
||||
continue;
|
||||
|
||||
for (auto candidate : bucketIt->second) {
|
||||
if (computeOrder.lookup(candidate.getOperation()) <= index)
|
||||
continue;
|
||||
if (consumed.contains(candidate))
|
||||
continue;
|
||||
if (!areEquivalentForRebatch(anchor, candidate))
|
||||
continue;
|
||||
|
||||
if (auto coreId = getComputeCoreId(candidate))
|
||||
if (!usedCoreIds.insert(*coreId).second)
|
||||
continue;
|
||||
|
||||
group.push_back(candidate);
|
||||
}
|
||||
|
||||
if (group.size() <= 1)
|
||||
continue;
|
||||
|
||||
auto insertionAnchor = group.front();
|
||||
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
|
||||
llvm::stable_sort(
|
||||
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
|
||||
}
|
||||
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(group.size() * anchor.getWeights().size());
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(group.size() * anchor.getInputs().size());
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(group.size());
|
||||
bool haveAllCoreIds = true;
|
||||
for (auto compute : group) {
|
||||
llvm::append_range(weights, compute.getWeights());
|
||||
llvm::append_range(inputs, compute.getInputs());
|
||||
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
haveAllCoreIds = false;
|
||||
else if (haveAllCoreIds)
|
||||
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(insertionAnchor);
|
||||
auto rebatched = SpatComputeBatch::create(rewriter,
|
||||
insertionAnchor.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
|
||||
ValueRange(weights),
|
||||
ValueRange(inputs));
|
||||
rebatched.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||
if (haveAllCoreIds)
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(arg.getLoc());
|
||||
}
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto& anchorBlock = anchor.getBody().front();
|
||||
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
|
||||
for (Operation& anchorOp : anchorBlock) {
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
|
||||
struct BatchReceiveEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchReceiveEntry> entries;
|
||||
entries.reserve(group.size());
|
||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
||||
BatchReceiveEntry entry;
|
||||
if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
||||
return;
|
||||
entries.push_back(entry);
|
||||
++opIts[groupIndex];
|
||||
}
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
channelIds.reserve(group.size());
|
||||
sourceCoreIds.reserve(group.size());
|
||||
targetCoreIds.reserve(group.size());
|
||||
for (const BatchReceiveEntry& entry : entries) {
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
|
||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
receiveOp.getOutput().getType(),
|
||||
channelIdValues,
|
||||
sourceCoreIdValues,
|
||||
targetCoreIdValues);
|
||||
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
|
||||
struct BatchSendEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchSendEntry> entries;
|
||||
entries.reserve(group.size());
|
||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
||||
BatchSendEntry entry;
|
||||
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
||||
return;
|
||||
entries.push_back(entry);
|
||||
++opIts[groupIndex];
|
||||
}
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
channelIds.reserve(group.size());
|
||||
sourceCoreIds.reserve(group.size());
|
||||
targetCoreIds.reserve(group.size());
|
||||
for (const BatchSendEntry& entry : entries) {
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
sendOp.getLoc(),
|
||||
channelIdValues,
|
||||
sourceCoreIdValues,
|
||||
targetCoreIdValues,
|
||||
mapper.lookup(sendOp.getInput()));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<spatial::SpatYieldOp>(anchorOp)) {
|
||||
for (auto& opIt : opIts)
|
||||
++opIt;
|
||||
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(anchorOp, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
for (auto& opIt : opIts)
|
||||
++opIt;
|
||||
}
|
||||
|
||||
for (auto compute : group) {
|
||||
compute->removeAttr(kRebatchPhaseAttrName);
|
||||
consumed.insert(compute);
|
||||
rewriter.eraseOp(compute);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||
compute->removeAttr(kRebatchPhaseAttrName);
|
||||
}
|
||||
|
||||
void cleanupDeadPackingOps(func::FuncOp funcOp) {
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
op.erase();
|
||||
};
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
{
|
||||
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
|
||||
orderBilateralChannelOps(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
|
||||
rebatchEquivalentComputes(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
|
||||
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
|
||||
compactBatchChannelRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-regular-op-runs");
|
||||
compactRegularOpRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
|
||||
compactRowWiseWvmmRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
|
||||
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
|
||||
compactBatchChannelRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
|
||||
cleanupDeadPackingOps(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
|
||||
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,14 +3,18 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -27,7 +31,7 @@ enum class RegularStepKind {
|
||||
|
||||
struct RegularStep {
|
||||
RegularStepKind kind;
|
||||
int32_t weightIndex = 0;
|
||||
Value weight;
|
||||
Value invariantOperand;
|
||||
Type resultType;
|
||||
};
|
||||
@@ -40,6 +44,122 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
struct RegularCompactionResult {
|
||||
bool changed = false;
|
||||
Operation* resumeAfter = nullptr;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct ConsecutiveRun {
|
||||
SmallVector<OpTy> ops;
|
||||
Block::iterator end;
|
||||
};
|
||||
|
||||
template <typename OpTy, typename Predicate>
|
||||
static ConsecutiveRun<OpTy>
|
||||
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
|
||||
ConsecutiveRun<OpTy> run;
|
||||
run.end = start;
|
||||
while (run.end != blockEnd) {
|
||||
auto current = dyn_cast<OpTy>(&*run.end);
|
||||
if (!current || !predicate(current))
|
||||
break;
|
||||
run.ops.push_back(current);
|
||||
++run.end;
|
||||
}
|
||||
return run;
|
||||
}
|
||||
|
||||
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
|
||||
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> getConstantI64Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int32_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
static SmallVector<Operation*> getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) {
|
||||
SmallVector<Operation*> defs;
|
||||
defs.reserve(metadataOperandCount);
|
||||
for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) {
|
||||
Operation* def = channelOp->getOperand(operandIndex).getDefiningOp();
|
||||
auto constantOp = dyn_cast_or_null<arith::ConstantOp>(def);
|
||||
if (!constantOp || def->getBlock() != channelOp->getBlock())
|
||||
continue;
|
||||
defs.push_back(def);
|
||||
}
|
||||
llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); });
|
||||
return defs;
|
||||
}
|
||||
|
||||
static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) {
|
||||
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
|
||||
metadataDef->moveBefore(insertionPoint);
|
||||
channelOp->moveBefore(insertionPoint);
|
||||
}
|
||||
|
||||
static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) {
|
||||
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
|
||||
metadataDef->moveBefore(block, insertionPoint);
|
||||
channelOp->moveBefore(block, insertionPoint);
|
||||
}
|
||||
|
||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||
if (values.empty() || !values.front().hasOneUse())
|
||||
return {};
|
||||
@@ -152,7 +272,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
||||
return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand
|
||||
&& lhs.resultType == rhs.resultType;
|
||||
}
|
||||
|
||||
@@ -166,14 +286,24 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||
}
|
||||
|
||||
static bool isForwardedChannelPayload(Value value, Block& block) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op || op->getBlock() != &block)
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
|
||||
|
||||
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
|
||||
}
|
||||
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
RegularChunk chunk;
|
||||
chunk.startOp = startOp.getOperation();
|
||||
chunk.input = startOp.getInput();
|
||||
chunk.output = startOp.getOutput();
|
||||
chunk.ops.push_back(startOp.getOperation());
|
||||
chunk.steps.push_back(
|
||||
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
|
||||
chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()});
|
||||
|
||||
Value currentValue = startOp.getOutput();
|
||||
while (currentValue.hasOneUse()) {
|
||||
@@ -186,9 +316,9 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
break;
|
||||
|
||||
if (vaddOp.getLhs() == currentValue)
|
||||
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||
chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||
else if (vaddOp.getRhs() == currentValue)
|
||||
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||
chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||
else
|
||||
break;
|
||||
|
||||
@@ -200,9 +330,11 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
static RegularCompactionResult
|
||||
compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run, OperationFolder& constantFolder) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
RegularCompactionResult result;
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
@@ -212,16 +344,16 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||
if (!packedInput)
|
||||
return;
|
||||
return result;
|
||||
|
||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
|
||||
auto packedInit = tensor::EmptyOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
|
||||
auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder);
|
||||
auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast<int64_t>(run.size()), constantFolder);
|
||||
auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder);
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
@@ -234,8 +366,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
|
||||
Value inputRowOffset = iv;
|
||||
if (inputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
|
||||
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder);
|
||||
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
@@ -264,8 +395,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
Value mappedOutput = mapping.lookup(anchorChunk.output);
|
||||
Value outputRowOffset = iv;
|
||||
if (outputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
|
||||
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder);
|
||||
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
@@ -315,41 +445,141 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
llvm::append_range(opsToErase, chunk.ops);
|
||||
for (Operation* op : llvm::reverse(opsToErase))
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
result.changed = true;
|
||||
result.resumeAfter = loop.getOperation()->getNextNode();
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
continue;
|
||||
|
||||
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
|
||||
Block& block = compute.getBody().front();
|
||||
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
||||
Operation* firstForwardedSend = nullptr;
|
||||
|
||||
for (Operation& op : block) {
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId)
|
||||
&& sourceCoreId == static_cast<uint32_t>(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||
if (!firstForwardedSend)
|
||||
firstForwardedSend = sendOp.getOperation();
|
||||
uint64_t key = getEndpointKey(sourceCoreId, targetCoreId);
|
||||
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|
||||
|| targetCoreId != static_cast<uint32_t>(coreId) || sourceCoreId >= static_cast<uint32_t>(coreId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), sourceCoreId);
|
||||
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
||||
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
||||
moves.push_back({receiveOp, firstMatchingSend->second});
|
||||
else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp))
|
||||
moves.push_back({receiveOp, firstForwardedSend});
|
||||
}
|
||||
|
||||
for (auto [receiveOp, insertionPoint] : moves)
|
||||
moveScalarChannelBundleBefore(receiveOp, insertionPoint);
|
||||
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|
||||
|| sourceCoreId >= static_cast<uint32_t>(coreId)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||
uint64_t currentChannelId = 0;
|
||||
uint32_t currentSourceCoreId = 0;
|
||||
uint32_t currentTargetCoreId = 0;
|
||||
return current.getOutput().getType() == outputType
|
||||
&& getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId)
|
||||
&& currentSourceCoreId < static_cast<uint32_t>(coreId);
|
||||
});
|
||||
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
||||
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||
uint64_t lhsChannelId = 0;
|
||||
uint32_t lhsSourceCoreId = 0;
|
||||
uint32_t lhsTargetCoreId = 0;
|
||||
uint64_t rhsChannelId = 0;
|
||||
uint32_t rhsSourceCoreId = 0;
|
||||
uint32_t rhsTargetCoreId = 0;
|
||||
bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId);
|
||||
bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId);
|
||||
if (!lhsHasMetadata || !rhsHasMetadata)
|
||||
return false;
|
||||
return lhsSourceCoreId > rhsSourceCoreId;
|
||||
});
|
||||
Block::iterator insertIt = run.end;
|
||||
for (auto op : sorted)
|
||||
moveScalarChannelBundleBefore(op, &block, insertIt);
|
||||
}
|
||||
|
||||
it = run.end;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||
return current.getOutput().getType() == outputType;
|
||||
});
|
||||
|
||||
bool hasRepeatedEndpoint = false;
|
||||
for (size_t lhs = 0; lhs < run.size() && !hasRepeatedEndpoint; ++lhs) {
|
||||
for (size_t rhs = lhs + 1; rhs < run.size(); ++rhs) {
|
||||
if (run[lhs].getSourceCoreId() == run[rhs].getSourceCoreId()
|
||||
&& run[lhs].getTargetCoreId() == run[rhs].getTargetCoreId()) {
|
||||
DenseSet<uint64_t> seenEndpoints;
|
||||
for (auto op : run.ops) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
}
|
||||
uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId);
|
||||
if (!seenEndpoints.insert(endpointKey).second) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (run.size() > 1 && !hasRepeatedEndpoint) {
|
||||
if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
|
||||
struct ReceiveEntry {
|
||||
spatial::SpatChannelReceiveOp op;
|
||||
size_t originalIndex = 0;
|
||||
@@ -358,9 +588,21 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<ReceiveEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run.ops)) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||
sortedEntries.clear();
|
||||
break;
|
||||
}
|
||||
sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId});
|
||||
}
|
||||
if (sortedEntries.empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -369,13 +611,12 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
SmallVector<Value> sortedOutputs;
|
||||
sortedOutputs.reserve(sortedEntries.size());
|
||||
@@ -388,14 +629,12 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||
auto compactReceive = spatial::SpatChannelReceiveTensorOp::create(
|
||||
rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues);
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(concatOp,
|
||||
concatStartIndex,
|
||||
@@ -408,7 +647,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
@@ -419,18 +658,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run =
|
||||
collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
|
||||
return current.getInput().getType() == inputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
struct SendEntry {
|
||||
spatial::SpatChannelSendOp op;
|
||||
uint32_t sourceCoreId = 0;
|
||||
@@ -438,9 +672,21 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<SendEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto op : run)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto op : run.ops) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||
sortedEntries.clear();
|
||||
break;
|
||||
}
|
||||
sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId});
|
||||
}
|
||||
if (sortedEntries.empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -451,26 +697,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
inputs.reserve(sortedEntries.size());
|
||||
for (SendEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
inputs.push_back(entry.op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||
spatial::SpatChannelSendTensorOp::create(
|
||||
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
it = run.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -489,32 +733,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
|
||||
return current.getOutput().getType() == outputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
for (auto op : run) {
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<Value> channelIds;
|
||||
SmallVector<Value> sourceCoreIds;
|
||||
SmallVector<Value> targetCoreIds;
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (auto op : run)
|
||||
outputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops)
|
||||
outputs.push_back(op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
@@ -523,24 +762,19 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create(
|
||||
rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds);
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
for (auto [index, op] : llvm::enumerate(run.ops))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
@@ -551,43 +785,34 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
|
||||
return current.getInput().getType() == inputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<Value> channelIds;
|
||||
SmallVector<Value> sourceCoreIds;
|
||||
SmallVector<Value> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
inputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
inputs.push_back(op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
spatial::SpatChannelSendTensorBatchOp::create(
|
||||
rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput);
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
it = run.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -600,6 +825,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
|
||||
void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
auto compactInBlock = [&](Block& block) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
@@ -615,8 +841,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
SmallVector<RegularChunk> run {*anchorChunk};
|
||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
auto runIt = anchorEndIt;
|
||||
while (runIt != block.end()) {
|
||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!candidateStart)
|
||||
@@ -631,12 +858,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
if (run.size() <= 1) {
|
||||
++it;
|
||||
it = anchorEndIt;
|
||||
continue;
|
||||
}
|
||||
|
||||
compactRegularChunkRun(rewriter, run);
|
||||
it = runIt;
|
||||
size_t originalOpCount = 0;
|
||||
for (const RegularChunk& chunk : run)
|
||||
originalOpCount += chunk.ops.size();
|
||||
|
||||
RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder);
|
||||
if (result.changed) {
|
||||
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||
if (!result.resumeAfter) {
|
||||
it = block.end();
|
||||
continue;
|
||||
}
|
||||
it = result.resumeAfter->getIterator();
|
||||
continue;
|
||||
}
|
||||
|
||||
it = anchorEndIt;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -648,6 +889,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
|
||||
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
@@ -667,37 +909,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatVMMOp> run;
|
||||
auto runIt = it;
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||
if (current.getWeight() != wvmmOp.getWeight()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
||||
break;
|
||||
}
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||
return false;
|
||||
|
||||
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||
break;
|
||||
return false;
|
||||
|
||||
run.push_back(current);
|
||||
++expectedRow;
|
||||
++runIt;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
if (run.size() <= 1) {
|
||||
if (run.ops.size() <= 1) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!run.front().getOutput().hasOneUse()) {
|
||||
if (!run.ops.front().getOutput().hasOneUse()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
auto concatUse = run.front().getOutput().getUses().begin();
|
||||
auto concatUse = run.ops.front().getOutput().getUses().begin();
|
||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||
if (!concatOp) {
|
||||
++it;
|
||||
@@ -706,7 +943,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||
bool validConcatRun = true;
|
||||
for (auto [index, op] : llvm::enumerate(run)) {
|
||||
for (auto [index, op] : llvm::enumerate(run.ops)) {
|
||||
if (!op.getOutput().hasOneUse()) {
|
||||
validConcatRun = false;
|
||||
break;
|
||||
@@ -737,17 +974,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
int64_t runLength = static_cast<int64_t>(run.size());
|
||||
int64_t runLength = static_cast<int64_t>(run.ops.size());
|
||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder);
|
||||
auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder);
|
||||
auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder);
|
||||
auto packedInit =
|
||||
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
@@ -758,41 +995,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
Value sourceRow = iv;
|
||||
if (firstRow != 0) {
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
||||
auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
inputType,
|
||||
extractRowsOp.getInput(),
|
||||
extractOffsets,
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto inserted = tensor::InsertSliceOp::create(
|
||||
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
||||
rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
SmallVector<Value> newConcatInputs;
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||
if (operandIndex == concatStartIndex)
|
||||
newConcatInputs.push_back(loop.getResult(0));
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
|
||||
newConcatInputs.push_back(operand);
|
||||
}
|
||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = loop->getIterator();
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
||||
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
Weight getComputeBodyWeight(Region& body) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto& block : body)
|
||||
for ([[maybe_unused]] auto& op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& block : body)
|
||||
for (auto& op : block)
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
bool isUsedAsWeightOnly(Operation* producerOp) {
|
||||
if (producerOp->getNumResults() == 0)
|
||||
return false;
|
||||
for (Value result : producerOp->getResults()) {
|
||||
if (result.use_empty())
|
||||
return false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||
if (!llvm::is_contained(compute.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||
if (!llvm::is_contained(batch.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (const ComputeGraphEdge& edge : edges) {
|
||||
if (edge.source == edge.target)
|
||||
continue;
|
||||
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (const auto& [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back({key.first, key.second, weight});
|
||||
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
|
||||
if (lhs.source != rhs.source)
|
||||
return lhs.source < rhs.source;
|
||||
return lhs.target < rhs.target;
|
||||
});
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeWeight(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeCrossbarUsage(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
||||
}
|
||||
|
||||
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
ComputeGraph graph;
|
||||
|
||||
for (Region& region : entryOp->getRegions()) {
|
||||
for (Block& block : region) {
|
||||
for (Operation& op : block) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||
continue;
|
||||
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back(
|
||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||
continue;
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
|
||||
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back(
|
||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
for (Value input : getComputeInstanceInputs(node.instance)) {
|
||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto producerInstance = getComputeProducerInstance(input, &node.instance);
|
||||
if (!producerInstance)
|
||||
continue;
|
||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
||||
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
||||
graph.successors.assign(graph.nodes.size(), {});
|
||||
graph.predecessors.assign(graph.nodes.size(), {});
|
||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
|
||||
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
bool verifyAcyclic(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readyNodes;
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
remainingParents[node] = graph.predecessors[node].size();
|
||||
if (remainingParents[node] == 0)
|
||||
readyNodes.push(node);
|
||||
}
|
||||
|
||||
size_t visited = 0;
|
||||
while (!readyNodes.empty()) {
|
||||
size_t node = readyNodes.front();
|
||||
readyNodes.pop();
|
||||
++visited;
|
||||
for (const auto& [child, weight] : graph.successors[node]) {
|
||||
(void) weight;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
return visited == graph.nodes.size();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../DCPGraph/Utils.hpp"
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeGraphNode {
|
||||
ComputeInstance instance;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
size_t originalOrder = 0;
|
||||
};
|
||||
|
||||
struct ComputeGraphEdge {
|
||||
size_t source = 0;
|
||||
size_t target = 0;
|
||||
Weight transferCost = 0;
|
||||
};
|
||||
|
||||
struct ComputeGraph {
|
||||
llvm::SmallVector<ComputeGraphNode> nodes;
|
||||
llvm::SmallVector<ComputeGraphEdge> edges;
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
};
|
||||
|
||||
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||
bool verifyAcyclic(const ComputeGraph &graph);
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeInstance {
|
||||
mlir::Operation *op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance &other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
||||
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
|
||||
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||
}
|
||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
||||
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
+193
@@ -0,0 +1,193 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
size_t getSchedulingCpuBudget() {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return static_cast<size_t>(laneCount);
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
|
||||
return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
|
||||
return {batch.getOperation(), lane, 1};
|
||||
}
|
||||
|
||||
static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
|
||||
if (extract.getMixedOffsets().empty())
|
||||
return std::nullopt;
|
||||
|
||||
OpFoldResult offset = extract.getMixedOffsets().front();
|
||||
if (Attribute attr = llvm::dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
if (!intAttr || intAttr.getInt() < 0)
|
||||
return std::nullopt;
|
||||
return static_cast<uint32_t>(intAttr.getInt());
|
||||
}
|
||||
|
||||
Value offsetValue = llvm::cast<Value>(offset);
|
||||
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
|
||||
if (constantIndex.value() < 0)
|
||||
return std::nullopt;
|
||||
return static_cast<uint32_t>(constantIndex.value());
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<ProducerValueRef> getResultfulBatchProducerValueRef(SpatComputeBatch batch,
|
||||
const ComputeInstance* consumerInstance) {
|
||||
if (!consumerInstance)
|
||||
return std::nullopt;
|
||||
if (!isa<SpatComputeBatch>(consumerInstance->op))
|
||||
return std::nullopt;
|
||||
if (consumerInstance->laneStart + consumerInstance->laneCount > static_cast<uint32_t>(batch.getLaneCount()))
|
||||
return std::nullopt;
|
||||
return ProducerValueRef {
|
||||
{batch.getOperation(), consumerInstance->laneStart, consumerInstance->laneCount},
|
||||
0
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(Value value, const ComputeInstance* consumerInstance) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
Value source = extract.getSource();
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
|
||||
if (batch && batch.getNumResults() != 0) {
|
||||
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
|
||||
if (*lane >= static_cast<uint32_t>(batch.getLaneCount()))
|
||||
return std::nullopt;
|
||||
return ProducerValueRef {
|
||||
{batch.getOperation(), *lane, 1},
|
||||
0
|
||||
};
|
||||
}
|
||||
return getResultfulBatchProducerValueRef(batch, consumerInstance);
|
||||
}
|
||||
|
||||
value = source;
|
||||
op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
return ProducerValueRef {
|
||||
ComputeInstance {compute.getOperation(), 0, 1},
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||
};
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||
if (batch.getNumResults() != 0)
|
||||
return getResultfulBatchProducerValueRef(batch, consumerInstance);
|
||||
uint32_t lane = cast<OpResult>(value).getResultNumber();
|
||||
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||
size_t resultIndex = lane - instance.laneStart;
|
||||
return ProducerValueRef {instance, resultIndex};
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(Value value, const ComputeInstance* consumerInstance) {
|
||||
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value, consumerInstance))
|
||||
return producer->instance;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
if (batch.getNumResults() != 0)
|
||||
return llvm::SmallVector<Value, 4>(batch.getInputs().begin(), batch.getInputs().end());
|
||||
|
||||
assert(batch.getInputs().size() % static_cast<size_t>(batch.getLaneCount()) == 0
|
||||
&& "resultless compute_batch inputs must be evenly partitioned by lane");
|
||||
size_t inputsPerLane = batch.getInputs().size() / static_cast<size_t>(batch.getLaneCount());
|
||||
llvm::SmallVector<Value, 4> inputs;
|
||||
inputs.reserve(instance.laneCount * inputsPerLane);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
|
||||
size_t firstInput = static_cast<size_t>(lane) * inputsPerLane;
|
||||
inputs.append(batch.getInputs().begin() + firstInput, batch.getInputs().begin() + firstInput + inputsPerLane);
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
if (batch.getNumResults() != 0)
|
||||
return llvm::SmallVector<Value, 4>(batch.getWeights().begin(), batch.getWeights().end());
|
||||
|
||||
assert(batch.getWeights().size() % static_cast<size_t>(batch.getLaneCount()) == 0
|
||||
&& "resultless compute_batch weights must be evenly partitioned by lane");
|
||||
size_t weightsPerLane = batch.getWeights().size() / static_cast<size_t>(batch.getLaneCount());
|
||||
llvm::SmallVector<Value, 4> weights;
|
||||
weights.reserve(instance.laneCount * weightsPerLane);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
|
||||
size_t firstWeight = static_cast<size_t>(lane) * weightsPerLane;
|
||||
weights.append(batch.getWeights().begin() + firstWeight, batch.getWeights().begin() + firstWeight + weightsPerLane);
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
if (batch.getNumResults() != 0)
|
||||
return llvm::SmallVector<Value, 4>(batch.getResults().begin(), batch.getResults().end());
|
||||
|
||||
llvm::SmallVector<Value, 4> outputs;
|
||||
outputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
if (!batch.getOutputs().empty())
|
||||
outputs.push_back(batch.getOutputs()[lane]);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance) {
|
||||
llvm::SmallVector<Type, 4> outputTypes;
|
||||
for (Value output : getComputeInstanceOutputValues(instance))
|
||||
outputTypes.push_back(output.getType());
|
||||
return outputTypes;
|
||||
}
|
||||
|
||||
Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return compute.getBody().front();
|
||||
return cast<SpatComputeBatch>(instance.op).getBody().front();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+41
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ProducerValueRef {
|
||||
ComputeInstance instance;
|
||||
size_t resultIndex = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget();
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
|
||||
const ComputeInstance *consumerInstance = nullptr);
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
|
||||
const ComputeInstance *consumerInstance = nullptr);
|
||||
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
|
||||
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,720 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "../DCPGraph/Graph.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
llvm::SmallVector<size_t, 4> originalNodeIndices;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
|
||||
struct VirtualGraph {
|
||||
std::vector<VirtualNode> nodes;
|
||||
std::vector<IndexedEdge> edges;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
std::vector<Time> aest;
|
||||
std::vector<Time> alst;
|
||||
std::vector<size_t> topologicalOrder;
|
||||
bool valid = false;
|
||||
};
|
||||
|
||||
struct WindowScheduleResult {
|
||||
std::vector<std::vector<size_t>> mergeGroups;
|
||||
CPU cpuCount = 0;
|
||||
size_t mergedNodeCount = 0;
|
||||
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
|
||||
if (options.processorCount > 0)
|
||||
return options.processorCount;
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (auto [start, end, weight] : edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
if (startIndex == endIndex)
|
||||
continue;
|
||||
auto key = std::make_pair(startIndex, endIndex);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (auto [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back(
|
||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
|
||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||
});
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
|
||||
VirtualGraph virtualGraph;
|
||||
virtualGraph.nodes.reserve(graph.nodes.size());
|
||||
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
|
||||
VirtualNode virtualNode;
|
||||
virtualNode.originalNodeIndices.push_back(index);
|
||||
virtualNode.weight = node.weight;
|
||||
virtualNode.crossbarUsage = node.crossbarUsage;
|
||||
virtualGraph.nodes.push_back(std::move(virtualNode));
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> edges;
|
||||
edges.reserve(graph.edges.size());
|
||||
for (const ComputeGraphEdge &edge : graph.edges)
|
||||
edges.push_back(
|
||||
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||
virtualGraph.edges = aggregateEdges(edges);
|
||||
return virtualGraph;
|
||||
}
|
||||
|
||||
TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||
TimingInfo timing;
|
||||
size_t nodeCount = graph.nodes.size();
|
||||
timing.aest.assign(nodeCount, 0);
|
||||
timing.alst.assign(nodeCount, 0);
|
||||
timing.topologicalOrder.reserve(nodeCount);
|
||||
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||
children[startIndex].push_back({endIndex, edgeWeight});
|
||||
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||
incomingEdgeCount[endIndex]++;
|
||||
}
|
||||
|
||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||
const VirtualNode &node = graph.nodes[nodeIndex];
|
||||
if (!node.originalNodeIndices.empty())
|
||||
return node.originalNodeIndices.front();
|
||||
return nodeIndex;
|
||||
};
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
||||
if (lhsKey != rhsKey)
|
||||
return lhsKey > rhsKey;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t i = 0; i < nodeCount; ++i)
|
||||
if (incomingEdgeCount[i] == 0)
|
||||
readyNodes.push(i);
|
||||
|
||||
while (!readyNodes.empty()) {
|
||||
size_t current = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
timing.topologicalOrder.push_back(current);
|
||||
for (auto [child, weight] : children[current]) {
|
||||
(void) weight;
|
||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||
incomingEdgeCount[child]--;
|
||||
if (incomingEdgeCount[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
if (timing.topologicalOrder.size() != nodeCount)
|
||||
return timing;
|
||||
|
||||
Time dcpl = 0;
|
||||
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||
Time maxParentAest = 0;
|
||||
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||
maxParentAest =
|
||||
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||
}
|
||||
timing.aest[nodeIndex] = maxParentAest;
|
||||
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||
}
|
||||
|
||||
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||
Time minAlst = std::numeric_limits<Time>::max();
|
||||
if (children[nodeIndex].empty())
|
||||
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||
minAlst =
|
||||
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||
}
|
||||
timing.alst[nodeIndex] = minAlst;
|
||||
}
|
||||
|
||||
timing.valid = true;
|
||||
return timing;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto &neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||
if (lhsSlack != rhsSlack)
|
||||
return lhsSlack < rhsSlack;
|
||||
if (timing.aest[lhs] != timing.aest[rhs])
|
||||
return timing.aest[lhs] < timing.aest[rhs];
|
||||
return lhs < rhs;
|
||||
};
|
||||
|
||||
windowSize = std::min(windowSize, ranked.size());
|
||||
if (windowSize == 0)
|
||||
return {};
|
||||
if (windowSize == ranked.size()) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
return ranked;
|
||||
}
|
||||
|
||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||
if (criticalPoolSize < ranked.size())
|
||||
std::nth_element(
|
||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||
|
||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||
inCriticalPool[ranked[i]] = true;
|
||||
|
||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||
std::vector<size_t> selected;
|
||||
std::vector<char> inWindow(ranked.size(), false);
|
||||
selected.reserve(windowSize);
|
||||
|
||||
struct FrontierEntry {
|
||||
size_t node;
|
||||
};
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour] && eligible[neighbour])
|
||||
frontier.push({neighbour});
|
||||
};
|
||||
|
||||
addToWindow(seed, inCriticalPool);
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, inCriticalPool);
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
std::vector<char> anyNode(ranked.size(), true);
|
||||
for (size_t node : selected)
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour])
|
||||
frontier.push({neighbour});
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, anyNode);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
for (size_t node : ranked) {
|
||||
if (selected.size() == windowSize)
|
||||
break;
|
||||
if (!inWindow[node]) {
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::sort(selected, isHigherPriority);
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||
if (mappedStart == -1 || mappedEnd == -1)
|
||||
continue;
|
||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||
}
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
|
||||
llvm::ArrayRef<size_t> selectedNodes,
|
||||
const DcpScheduleOptions &options,
|
||||
mlir::MLIRContext *context) {
|
||||
std::vector<Weight> windowWeights;
|
||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||
std::vector<int64_t> windowNodeOrderKeys;
|
||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||
windowWeights.reserve(selectedNodes.size());
|
||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
||||
|
||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
||||
}
|
||||
|
||||
GraphDCP windowGraph(
|
||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
||||
if (options.processorCount > 0)
|
||||
windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||
windowGraph.setContext(context);
|
||||
windowGraph.runDcp();
|
||||
|
||||
WindowScheduleResult result;
|
||||
result.cpuCount = windowGraph.cpuCount();
|
||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.size() < 2)
|
||||
continue;
|
||||
|
||||
result.mergedNodeCount += scheduledTasks.size();
|
||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||
std::vector<size_t> mergeGroup;
|
||||
mergeGroup.reserve(scheduledTasks.size());
|
||||
for (const auto &task : scheduledTasks)
|
||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph &graph,
|
||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph &coarsenedGraph,
|
||||
std::vector<size_t> &oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
if (timing.valid)
|
||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||
topologicalRank[nodeIndex] = rank;
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto &mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||
return lhs < rhs;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
continue;
|
||||
for (size_t nodeIndex : mergeGroup) {
|
||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||
std::vector<size_t> newNodeRank;
|
||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||
bool mergedAny = false;
|
||||
coarsenedGraph.nodes.clear();
|
||||
coarsenedGraph.edges.clear();
|
||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||
newNodeRank.reserve(graph.nodes.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||
if (mergeGroupIndex == -1) {
|
||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode &memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
|
||||
|
||||
mergedAny = true;
|
||||
newNodeIndex = coarsenedGraph.nodes.size();
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||
}
|
||||
|
||||
if (!mergedAny)
|
||||
return false;
|
||||
|
||||
std::vector<IndexedEdge> remappedEdges;
|
||||
remappedEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||
if (newStart == newEnd)
|
||||
continue;
|
||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||
continue;
|
||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||
}
|
||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
|
||||
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
|
||||
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
|
||||
nodeIndexByInstance.reserve(graph.nodes.size());
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
nodeIndexByInstance[node.instance] = nodeIndex;
|
||||
|
||||
struct ScheduledEdge {
|
||||
size_t target = 0;
|
||||
Time delay = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
|
||||
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
|
||||
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
|
||||
|
||||
Time delay = graph.nodes[edge.source].weight;
|
||||
if (sourceCpu != targetCpu)
|
||||
delay = addOrMax(delay, edge.transferCost);
|
||||
|
||||
scheduledChildren[edge.source].push_back({edge.target, delay});
|
||||
incomingEdgeCount[edge.target]++;
|
||||
}
|
||||
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
size_t cpu = result.computeToCpuMap.lookup(node.instance);
|
||||
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
|
||||
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
|
||||
size_t sourceIndex = scheduledTasks[i - 1].second;
|
||||
size_t targetIndex = scheduledTasks[i].second;
|
||||
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
|
||||
incomingEdgeCount[targetIndex]++;
|
||||
}
|
||||
}
|
||||
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
|
||||
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
|
||||
if (incomingEdgeCount[nodeIndex] == 0)
|
||||
readyNodes.push(nodeIndex);
|
||||
|
||||
std::vector<Time> startTimes(graph.nodes.size(), 0);
|
||||
size_t processedNodeCount = 0;
|
||||
while (!readyNodes.empty()) {
|
||||
size_t sourceIndex = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
processedNodeCount++;
|
||||
|
||||
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
|
||||
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
|
||||
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
|
||||
incomingEdgeCount[edge.target]--;
|
||||
if (incomingEdgeCount[edge.target] == 0)
|
||||
readyNodes.push(edge.target);
|
||||
}
|
||||
}
|
||||
|
||||
if (processedNodeCount != graph.nodes.size())
|
||||
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
|
||||
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
|
||||
}
|
||||
|
||||
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
|
||||
MergeScheduleResult result;
|
||||
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> virtualNodeOrder;
|
||||
if (timing.valid)
|
||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||
else {
|
||||
virtualNodeOrder.resize(graph.nodes.size());
|
||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||
}
|
||||
|
||||
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
|
||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
|
||||
for (size_t originalIndex : virtualNode.originalNodeIndices)
|
||||
originalNodeToCpu[originalIndex] = cpu;
|
||||
}
|
||||
|
||||
result.dominanceOrderCompute.reserve(originalGraph.nodes.size());
|
||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||
for (auto [originalIndex, node] : llvm::enumerate(originalGraph.nodes)) {
|
||||
size_t cpu = originalNodeToCpu[originalIndex];
|
||||
result.dominanceOrderCompute.push_back(node.instance);
|
||||
result.computeToCpuMap[node.instance] = cpu;
|
||||
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
|
||||
result.cpuToLastComputeMap[cpu] = node.instance;
|
||||
}
|
||||
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
result.isLastComputeOfCpu.insert(lastCompute);
|
||||
assignFeasibleAest(originalGraph, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
||||
for (const ComputeGraphNode &node : graph.nodes)
|
||||
result.dominanceOrderCompute.push_back(node.instance);
|
||||
|
||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.empty())
|
||||
continue;
|
||||
|
||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
|
||||
result.computeToCpuMap[instance] = cpu;
|
||||
result.computeToCpuSlotMap[instance] = slot;
|
||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||
}
|
||||
|
||||
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
|
||||
result.cpuToLastComputeMap[cpu] = lastInstance;
|
||||
result.isLastComputeOfCpu.insert(lastInstance);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||
llvm::SmallVector<Weight> nodeWeights;
|
||||
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||
llvm::SmallVector<IndexedEdge> edges;
|
||||
nodeWeights.reserve(graph.nodes.size());
|
||||
nodeCrossbarUsage.reserve(graph.nodes.size());
|
||||
nodeOrderKeys.reserve(graph.nodes.size());
|
||||
edges.reserve(graph.edges.size());
|
||||
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
nodeWeights.push_back(node.weight);
|
||||
nodeCrossbarUsage.push_back(node.crossbarUsage);
|
||||
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
|
||||
}
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
edges.push_back(
|
||||
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||
}
|
||||
|
||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||
if (options.processorCount > 0)
|
||||
graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||
graphDCP.setContext(context);
|
||||
graphDCP.runDcp();
|
||||
return buildResultFromScheduledGraph(graphDCP, graph);
|
||||
}
|
||||
|
||||
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
|
||||
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
|
||||
return false;
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
|
||||
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||
if (needsExactScheduledBatches(graph, options))
|
||||
return runLegacyDcp(graph, options, context);
|
||||
|
||||
if (options.criticalWindowSize == 0)
|
||||
return runLegacyDcp(graph, options, context);
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, options, context);
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount);
|
||||
return false;
|
||||
}
|
||||
|
||||
VirtualGraph coarsenedGraph;
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount,
|
||||
windowSchedule.mergeGroups.size(),
|
||||
windowSchedule.mergedNodeCount,
|
||||
windowSchedule.maxMergeGroupSize,
|
||||
coarsenedGraph.nodes.size(),
|
||||
oldNodeCount - coarsenedGraph.nodes.size());
|
||||
virtualGraph = std::move(coarsenedGraph);
|
||||
return true;
|
||||
};
|
||||
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget(options)) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
llvm::SmallVector<size_t> selectedNodes;
|
||||
auto criticalWindow =
|
||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size(), options));
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
selectedNodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
return buildResultFromVirtualGraph(virtualGraph, graph);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct DcpScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
size_t criticalWindowSize = 0;
|
||||
bool allowFallbackForAutoCoreCount = true;
|
||||
};
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeInstance.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct MergeScheduleResult {
|
||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+139
@@ -0,0 +1,139 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "MergeSchedulingAnalysis.hpp"
|
||||
#include "PeftScheduler.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
MergeSchedulerKind getSchedulerKind() {
|
||||
switch (pimMergeScheduler.getValue()) {
|
||||
case MergeSchedulerPeft:
|
||||
return MergeSchedulerKind::Peft;
|
||||
case MergeSchedulerDcp:
|
||||
return MergeSchedulerKind::Dcp;
|
||||
}
|
||||
llvm_unreachable("unknown merge scheduler kind");
|
||||
}
|
||||
|
||||
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
|
||||
if (!result.computeToCpuMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
|
||||
if (!result.computeToCpuSlotMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
|
||||
if (!result.computeToAestMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing start time");
|
||||
|
||||
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
|
||||
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
CrossbarUsage usedCrossbars = 0;
|
||||
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||
if (scheduledTasks[slot].first != slot)
|
||||
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||
if (usedCrossbars > crossbarCapacity)
|
||||
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||
}
|
||||
|
||||
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
|
||||
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
|
||||
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
|
||||
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
|
||||
if (!result.isLastComputeOfCpu.count(expectedLast))
|
||||
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
|
||||
}
|
||||
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
const ComputeInstance source = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance target = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
|
||||
const size_t targetCpu = result.computeToCpuMap.lookup(target);
|
||||
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
|
||||
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
|
||||
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
|
||||
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
|
||||
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
||||
if (sourceCpu != targetCpu)
|
||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||
if (targetStart < earliestTargetStart) {
|
||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||
graph.nodes[edge.source].originalOrder,
|
||||
graph.nodes[edge.target].originalOrder)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
|
||||
: entryOp(op) {
|
||||
result = run();
|
||||
}
|
||||
|
||||
MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||
verifyExplicitPimCoreCount();
|
||||
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||
if (!verifyAcyclic(graph))
|
||||
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
|
||||
|
||||
MergeSchedulingOptions options;
|
||||
options.kind = getSchedulerKind();
|
||||
if (coresCount.getValue() > 0)
|
||||
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||
|
||||
MergeScheduleResult schedule;
|
||||
if (options.kind == MergeSchedulerKind::Peft) {
|
||||
schedule = runPeftScheduler(
|
||||
graph,
|
||||
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||
entryOp->getContext()});
|
||||
}
|
||||
else {
|
||||
schedule = runDcpScheduler(
|
||||
graph,
|
||||
DcpScheduleOptions {
|
||||
options.processorCount,
|
||||
dcpCriticalWindowSize.getValue(),
|
||||
options.allowDcpFallbackForAutoCoreCount
|
||||
},
|
||||
entryOp->getContext());
|
||||
}
|
||||
|
||||
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
||||
return schedule;
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
enum class MergeSchedulerKind {
|
||||
Dcp,
|
||||
Peft,
|
||||
};
|
||||
|
||||
struct MergeSchedulingOptions {
|
||||
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
|
||||
size_t processorCount = 0;
|
||||
bool allowDcpFallbackForAutoCoreCount = true;
|
||||
};
|
||||
|
||||
class MergeSchedulingAnalysis {
|
||||
public:
|
||||
explicit MergeSchedulingAnalysis(mlir::Operation *op);
|
||||
MergeScheduleResult &getResult() { return result; }
|
||||
|
||||
private:
|
||||
mlir::Operation *entryOp = nullptr;
|
||||
MergeScheduleResult result;
|
||||
|
||||
MergeScheduleResult run();
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,362 @@
|
||||
#include "mlir/IR/Threading.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "PeftScheduler.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ScheduledTask {
|
||||
size_t processor = std::numeric_limits<size_t>::max();
|
||||
Time startTime = 0;
|
||||
Time endTime = 0;
|
||||
size_t slot = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readySinks;
|
||||
std::vector<std::vector<size_t>> reverseLevels;
|
||||
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
remainingSuccessors[node] = graph.successors[node].size();
|
||||
if (remainingSuccessors[node] == 0)
|
||||
readySinks.push(node);
|
||||
}
|
||||
|
||||
size_t levelizedCount = 0;
|
||||
while (!readySinks.empty()) {
|
||||
size_t levelSize = readySinks.size();
|
||||
std::vector<size_t> levelNodes;
|
||||
levelNodes.reserve(levelSize);
|
||||
for (size_t i = 0; i < levelSize; ++i) {
|
||||
size_t node = readySinks.front();
|
||||
readySinks.pop();
|
||||
levelNodes.push_back(node);
|
||||
++levelizedCount;
|
||||
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
||||
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||
if (--remainingSuccessors[pred] == 0)
|
||||
readySinks.push(pred);
|
||||
}
|
||||
}
|
||||
reverseLevels.push_back(std::move(levelNodes));
|
||||
}
|
||||
|
||||
if (levelizedCount != graph.nodes.size())
|
||||
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
|
||||
|
||||
return reverseLevels;
|
||||
}
|
||||
|
||||
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||
constexpr size_t kMaxOctTableBytes = 1ull << 30;
|
||||
if (nodeCount == 0 || processorCount == 0)
|
||||
return;
|
||||
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
||||
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||
size_t rowBytes = processorCount * sizeof(Time);
|
||||
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
|
||||
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||
size_t totalBytes = nodeCount * rowBytes;
|
||||
if (totalBytes > kMaxOctTableBytes) {
|
||||
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
|
||||
totalBytes / (1024 * 1024))
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||
const size_t nodeCount = graph.nodes.size();
|
||||
const size_t processorCount = options.processorCount;
|
||||
if (processorCount == 0)
|
||||
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||
|
||||
verifyOctTableSize(nodeCount, processorCount);
|
||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||
|
||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
||||
|
||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||
|
||||
// 1. O(P(E+V)) Heterogeneous OCT Calculation
|
||||
for (const std::vector<size_t>& levelNodes : reverseLevels) {
|
||||
auto computeNodeOct = [&](size_t levelIndex) {
|
||||
size_t task = levelNodes[levelIndex];
|
||||
std::vector<Time> maxVals(processorCount, 0);
|
||||
|
||||
for (const auto& [succ, comm] : graph.successors[task]) {
|
||||
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
|
||||
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
||||
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
||||
}
|
||||
}
|
||||
|
||||
Time minForPreds = std::numeric_limits<Time>::max();
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
oct[task * processorCount + processor] = maxVals[processor];
|
||||
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
|
||||
}
|
||||
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
||||
};
|
||||
|
||||
if (options.context != nullptr)
|
||||
mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
|
||||
else
|
||||
for (size_t i = 0; i < levelNodes.size(); ++i)
|
||||
computeNodeOct(i);
|
||||
}
|
||||
|
||||
struct RankEntry {
|
||||
long double rank = 0.0L;
|
||||
size_t node = 0;
|
||||
size_t originalOrder = 0;
|
||||
};
|
||||
std::vector<RankEntry> ranks(nodeCount);
|
||||
auto computeRank = [&](size_t node) {
|
||||
long double rank = 0.0L;
|
||||
for (size_t processor = 0; processor < processorCount; ++processor)
|
||||
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
||||
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
||||
};
|
||||
|
||||
if (options.context != nullptr)
|
||||
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
||||
else
|
||||
for (size_t node = 0; node < nodeCount; ++node)
|
||||
computeRank(node);
|
||||
|
||||
auto readyCompare = [&](size_t lhs, size_t rhs) {
|
||||
const RankEntry& lhsRank = ranks[lhs];
|
||||
const RankEntry& rhsRank = ranks[rhs];
|
||||
if (lhsRank.rank != rhsRank.rank)
|
||||
return lhsRank.rank < rhsRank.rank;
|
||||
if (lhsRank.originalOrder != rhsRank.originalOrder)
|
||||
return lhsRank.originalOrder > rhsRank.originalOrder;
|
||||
return lhs > rhs;
|
||||
};
|
||||
|
||||
std::vector<int> remainingParents(nodeCount, 0);
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
|
||||
for (size_t node = 0; node < nodeCount; ++node) {
|
||||
remainingParents[node] = graph.predecessors[node].size();
|
||||
if (remainingParents[node] == 0)
|
||||
readyQueue.push(node);
|
||||
}
|
||||
|
||||
std::vector<char> scheduled(nodeCount, false);
|
||||
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
||||
std::vector<ScheduledTask> schedules(nodeCount);
|
||||
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||
|
||||
size_t scheduledCount = 0;
|
||||
while (!readyQueue.empty()) {
|
||||
size_t task = readyQueue.top();
|
||||
readyQueue.pop();
|
||||
if (scheduled[task])
|
||||
continue;
|
||||
|
||||
size_t bestProcessor = std::numeric_limits<size_t>::max();
|
||||
Time bestEst = 0;
|
||||
Time bestEft = 0;
|
||||
Time bestOeft = std::numeric_limits<Time>::max();
|
||||
bool crossbarRejected = false;
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
if (graph.nodes[task].crossbarUsage != 0
|
||||
&& addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
||||
crossbarRejected = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
Time dataReady = 0;
|
||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||
const ScheduledTask& predSchedule = schedules[pred];
|
||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||
}
|
||||
|
||||
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||
Time compWeight = getComputeCost(task, processor);
|
||||
Time est = dataReady;
|
||||
Time currentEnd = 0;
|
||||
bool foundGap = false;
|
||||
|
||||
for (size_t schedTaskIndex : tasksByProcessor[processor]) {
|
||||
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||
Time gapStart = std::max(currentEnd, dataReady);
|
||||
|
||||
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
||||
est = gapStart;
|
||||
foundGap = true;
|
||||
break;
|
||||
}
|
||||
currentEnd = schedTask.endTime;
|
||||
}
|
||||
|
||||
if (!foundGap)
|
||||
est = std::max(currentEnd, dataReady);
|
||||
|
||||
Time eft = addOrMax(est, compWeight);
|
||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||
|
||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
||||
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
}
|
||||
}
|
||||
|
||||
if (bestProcessor == std::numeric_limits<size_t>::max()) {
|
||||
if (crossbarRejected) {
|
||||
std::string message =
|
||||
llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
|
||||
graph.nodes[task].originalOrder,
|
||||
options.crossbarCapacity)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
|
||||
graph.nodes[task].originalOrder,
|
||||
processorCount)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
|
||||
schedules[task] = {bestProcessor, bestEst, bestEft, 0};
|
||||
scheduled[task] = true;
|
||||
++scheduledCount;
|
||||
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||
|
||||
// 3. CRITICAL FIX: Topological Append
|
||||
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||
// back guarantees the Monoliths will be physically generated cycle-free.
|
||||
// The hardware will still benefit from the processor assignment chosen by PEFT.
|
||||
tasksByProcessor[bestProcessor].push_back(task);
|
||||
|
||||
for (const auto& [child, weight] : graph.successors[task]) {
|
||||
(void) weight;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
readyQueue.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduledCount != nodeCount)
|
||||
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
||||
|
||||
// 4. Build Strict Topological Dominance Order
|
||||
std::vector<size_t> scheduledOrder(nodeCount);
|
||||
for (size_t i = 0; i < nodeCount; ++i)
|
||||
scheduledOrder[i] = i;
|
||||
|
||||
std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
|
||||
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||
});
|
||||
|
||||
// 5. Check if equal schedule in two level
|
||||
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
|
||||
for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) {
|
||||
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
|
||||
continue;
|
||||
auto& currentTasks = tasksByProcessor[currentProcessor];
|
||||
auto& controlTasks = tasksByProcessor[controlProcessor];
|
||||
bool equalSchedule = true;
|
||||
|
||||
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
|
||||
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
|
||||
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
|
||||
if (currentComputeInstance.op != controlComputeInstance.op) {
|
||||
equalSchedule = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (equalSchedule) {
|
||||
equivalentClass[currentProcessor].push_back(controlProcessor);
|
||||
equivalentClass[controlProcessor].push_back(currentProcessor);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
||||
std::vector<bool> visited(processorCount, false);
|
||||
size_t uniqueClassCount = 0;
|
||||
|
||||
for (size_t i = 0; i < processorCount; ++i) {
|
||||
if (visited[i])
|
||||
continue;
|
||||
|
||||
// We found a new unique schedule (equivalence class)
|
||||
++uniqueClassCount;
|
||||
visited[i] = true;
|
||||
|
||||
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
||||
|
||||
// Find and mark all identical companions
|
||||
auto it = equivalentClass.find(i);
|
||||
if (it != equivalentClass.end()) {
|
||||
for (size_t eqCpu : it->second) {
|
||||
if (!visited[eqCpu]) {
|
||||
llvm::dbgs() << ", " << eqCpu;
|
||||
visited[eqCpu] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
llvm::dbgs() << " }\n";
|
||||
}
|
||||
|
||||
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
||||
llvm::dbgs() << "--------------------------------------\n";
|
||||
}
|
||||
|
||||
// 6. Populate Final Result
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(nodeCount);
|
||||
|
||||
for (size_t task : scheduledOrder)
|
||||
result.dominanceOrderCompute.push_back(graph.nodes[task].instance);
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
size_t currentSlot = 0;
|
||||
for (size_t task : tasksByProcessor[processor]) {
|
||||
const ComputeInstance instance = graph.nodes[task].instance;
|
||||
result.computeToCpuMap[instance] = processor;
|
||||
result.computeToCpuSlotMap[instance] = currentSlot++;
|
||||
result.computeToAestMap[instance] = schedules[task].startTime;
|
||||
}
|
||||
if (!tasksByProcessor[processor].empty()) {
|
||||
const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
|
||||
result.cpuToLastComputeMap[processor] = lastInstance;
|
||||
result.isLastComputeOfCpu.insert(lastInstance);
|
||||
}
|
||||
}
|
||||
|
||||
result.equivalentClass = equivalentClass;
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct PeftScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
CrossbarUsage crossbarCapacity = 0;
|
||||
mlir::MLIRContext *context = nullptr;
|
||||
};
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -24,9 +24,12 @@ struct EmitPimCodePass : PassWrapper<EmitPimCodePass, OperationPass<ModuleOp>> {
|
||||
createDirectory(pimDir);
|
||||
|
||||
int compiler_error_code = compileToPimCode(moduleOp, pimDir);
|
||||
if (compiler_error_code != CompilerSuccess)
|
||||
if (compiler_error_code != CompilerSuccess) {
|
||||
moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code "
|
||||
<< compiler_error_code;
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationP
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
|
||||
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim3_folded");
|
||||
dumpModule(moduleOp, "pim3_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
|
||||
@@ -268,24 +268,31 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
|
||||
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
|
||||
if (failed(dstOffset) || failed(srcOffset))
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getDeviceTarget(),
|
||||
copyOp.getHostSource(),
|
||||
copyOp.getDeviceTargetOffset(),
|
||||
copyOp.getHostSourceOffset(),
|
||||
*dstOffset,
|
||||
*srcOffset,
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
Value dstOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset);
|
||||
Value srcOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset);
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dstOffsetValue,
|
||||
srcOffsetValue,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
@@ -301,24 +308,31 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
|
||||
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
|
||||
if (failed(dstOffset) || failed(srcOffset))
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getHostTarget(),
|
||||
copyOp.getDeviceSource(),
|
||||
copyOp.getHostTargetOffset(),
|
||||
copyOp.getDeviceSourceOffset(),
|
||||
*dstOffset,
|
||||
*srcOffset,
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/false,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
Value dstOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset);
|
||||
Value srcOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset);
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dstOffset,
|
||||
srcOffset,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
@@ -355,9 +369,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType =
|
||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultMemRefType = cast<MemRefType>(subviewOp.getType());
|
||||
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
@@ -23,11 +23,11 @@ namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
return operandIndex == 3;
|
||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -39,7 +39,10 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) {
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
@@ -48,6 +51,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
||||
});
|
||||
|
||||
for (Operation* op : ops) {
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
@@ -105,14 +111,15 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
||||
.getOutput();
|
||||
}
|
||||
else {
|
||||
copiedValue = pim::PimMemCopyHostToDevOp::create(
|
||||
copiedValue =
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
getOrCreateHostIndexConstant(op, 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder),
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||
.getOutput();
|
||||
}
|
||||
@@ -134,6 +141,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
IRRewriter rewriter(moduleOp.getContext());
|
||||
OperationFolder constantFolder(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
@@ -141,10 +149,10 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||
materializeHostConstantsInCore(coreOp, rewriter, hasFailure);
|
||||
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure);
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
@@ -160,6 +168,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -96,17 +98,49 @@ static bool isConstantGlobalView(Value value) {
|
||||
value = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
value = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
value = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
return operandIndex == 3;
|
||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isCoreWeightBlockArgument(Value value) {
|
||||
auto blockArgument = dyn_cast<BlockArgument>(value);
|
||||
if (!blockArgument)
|
||||
return false;
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(blockArgument.getOwner()->getParentOp()))
|
||||
return static_cast<unsigned>(blockArgument.getArgNumber()) < coreOp.getWeights().size();
|
||||
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(blockArgument.getOwner()->getParentOp())) {
|
||||
unsigned argNumber = static_cast<unsigned>(blockArgument.getArgNumber());
|
||||
return argNumber > 0 && argNumber <= coreBatchOp.getWeights().size();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -152,14 +186,15 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
op->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
hasFailure = true;
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
});
|
||||
});
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
@@ -168,49 +203,58 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
|
||||
(void) verifyCoreOperands(coreOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||
(void) withScalarCoreFromBatchLane(
|
||||
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); });
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
if (failed(verifyReturnOp(returnOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyReturnOp(returnOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
hasFailure = true;
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
||||
hasFailure = true;
|
||||
(void) verifyAddressOnlyHostOp(&op, diagnostics);
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure)
|
||||
if (diagnostics.hasFailure()) {
|
||||
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
|
||||
static LogicalResult
|
||||
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
for (auto it : llvm::enumerate(coreOp.getWeights())) {
|
||||
size_t weightIndex = it.index();
|
||||
Value weight = it.value();
|
||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as a constant memref.global or a static view of one before JSON "
|
||||
"codegen";
|
||||
<< " must be materialized as a constant memref.global or a static view of one before "
|
||||
"JSON codegen";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
@@ -220,14 +264,18 @@ private:
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -235,11 +283,15 @@ private:
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||
size_t resultIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
diagnostics.report(returnOp.getOperation(), [&](Operation*) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -247,38 +299,73 @@ private:
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
|
||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
return walkPimCoreBlock(
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
if (!isSupportedCoreInstructionOp(&op)) {
|
||||
op.emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
for (auto it : llvm::enumerate(op.getOperands())) {
|
||||
size_t operandIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
if (isCoreWeightBlockArgument(operand))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with "
|
||||
"pim.memcp_hd";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
||||
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
|
||||
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -286,18 +373,20 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource());
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
@@ -305,19 +394,24 @@ private:
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
static LogicalResult
|
||||
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isBaseAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user