Compare commits

25 Commits

Author SHA1 Message Date
NiccoloN
bdacb9871d fix dcp merge bug
Some checks failed
Validate Operations / validate-operations (push) Failing after 15m54s
2026-05-04 15:58:14 +02:00
NiccoloN
5b9bb0c191 refactor spatial ops
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m55s
2026-05-04 14:19:30 +02:00
NiccoloN
f789954ad7 Refactor ONNXToSpatial Common and diagnostics 2026-05-04 13:42:43 +02:00
ilgeco
b6ba1e4fea Fix DCPTest using old constructor
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m15s
2026-05-04 10:58:51 +02:00
NiccoloN
717ad160cd Refactor PIM/Common (splitting in files, adding helpers, adding brief
Some checks failed
Validate Operations / validate-operations (push) Failing after 18m36s
docs)
2026-05-04 09:20:43 +02:00
NiccoloN
905fa9f9a7 Merge remote changes
Some checks failed
Validate Operations / validate-operations (push) Failing after 18m42s
2026-05-03 23:09:32 +02:00
NiccoloN
62b0a6e19d merge remote changes 2026-05-03 22:30:46 +02:00
NiccoloN
b605585b1f compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
2026-05-03 14:14:14 +02:00
ilgeco
08b0fcd850 Parallel bufferization
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m49s
2026-04-30 11:48:17 +02:00
ilgeco
9dccc2c701 Translate global constant to symble 2026-04-28 12:42:01 +02:00
ilgeco
5c839e62c1 Func Input converted to symbol 2026-04-27 13:48:03 +02:00
NiccoloN
15e8edb9c4 better spat computes merging
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m14s
2026-04-25 19:24:09 +02:00
ilgeco
951baca106 Merge Node update fix comparison bug
All checks were successful
Validate Operations / validate-operations (push) Successful in 20m21s
2026-04-23 19:52:16 +02:00
ilgeco
fc5bccb487 Merge Node update status file
Some checks are pending
Validate Operations / validate-operations (push) Has started running
2026-04-23 19:42:56 +02:00
ilgeco
49dea15b95 DCP Merge status
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m29s
2026-04-23 18:40:33 +02:00
NiccoloN
5545b0f672 fix MatMul pattern non-contiguous extract_slices
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
2026-04-23 14:44:30 +02:00
NiccoloN
cff929a083 fix sigmoid implementation stability in pim-simulator
All checks were successful
Validate Operations / validate-operations (push) Successful in 23m4s
2026-04-23 10:34:29 +02:00
NiccoloN
89b3501aa8 fix weightAlways attribute in spatial 2026-04-23 10:04:47 +02:00
NiccoloN
412ca957f6 multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
2026-04-23 09:28:57 +02:00
NiccoloN
0f13269040 faster DCPAnalysis on partial graph
All checks were successful
Validate Operations / validate-operations (push) Successful in 27m37s
2026-04-21 18:36:16 +02:00
NiccoloN
dafc1d15b7 faster pim-simulator 2026-04-21 18:35:51 +02:00
NiccoloN
3fa140be25 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-04-21 16:23:16 +02:00
ilgeco
df703f0be9 pim-simulator add progress report
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m24s
2026-04-21 16:23:03 +02:00
NiccoloN
9fa850c140 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-04-21 15:59:08 +02:00
NiccoloN
25ade1bd63 fix memory allocation in pim codegen
fix crossbar allocation to only consider weights from vmm and mvm
2026-04-21 13:31:10 +02:00
86 changed files with 9323 additions and 3088 deletions

10
.gitignore vendored
View File

@@ -1,5 +1,15 @@
.zed
.idea .idea
**/.vscode **/.vscode
.claude .claude
.codex
AGENTS.md AGENTS.md
CMakeUserPresets.json
build build
cmake-build-debug
cmake-build-release
**/__*

154
README.md
View File

@@ -1,5 +1,159 @@
# Raptor # Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
## Overview
PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
- a shared host memory,
- a number of cores that do most of the computation directly in their memory
(vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
support — any repeated work (e.g. convolutions) must be unrolled into
explicit per-iteration instructions.
Because of this, the amount of emitted instructions explodes quickly and the
compiler must optimize aggressively at every stage to keep compilation
tractable.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.
## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).
High-level lowering flow:
```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
```
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
operations are accelerated by storing a constant RHS matrix into a
crossbar. Crossbars cannot be re-programmed during execution, have a
limited fixed size, and there is a limited number of them per core.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
materializes PIM cores (`pim.core`), inter-core communication
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original
scheduling paper by Kwok & Ahmad,
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
standard MLIR `BufferizableOpInterface` machinery
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
- `HostConstantFolding` — folds host-side constants.
- `MaterializeHostConstantsPass` — materializes the remaining host
constants for emission.
- `VerificationPass` — checks invariants before emission.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM:
- `--maccel=PIM` — select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
## Validation
Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.
Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.
## Rebuilding
Release build (fast):
```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
```
A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
## Build ## Build
### Protobuf ### Protobuf

View File

@@ -55,17 +55,25 @@ pub trait HasSigm {
impl HasSigm for f32 { impl HasSigm for f32 {
fn sigm(self) -> Self { fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp(); let ex = self.exp();
ex / (1.0 + ex) ex / (1.0 + ex)
} }
} }
}
impl HasSigm for f64 { impl HasSigm for f64 {
fn sigm(self) -> Self { fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp(); let ex = self.exp();
ex / (1.0 + ex) ex / (1.0 + ex)
} }
} }
}
pub trait HasExp { pub trait HasExp {
fn exp(self) -> Self; fn exp(self) -> Self;

View File

@@ -1,50 +1,54 @@
#![allow(unused)] #![allow(unused)]
use std::time::{Duration, SystemTime};
use crate::{ use crate::{
cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER cpu::CPU,
instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name},
memory_manager::type_traits::TryToUsize,
send_recv::{SendRecv, handle_send_recv},
tracing::TRACER,
}; };
pub mod cpu; pub mod cpu;
pub mod instruction_set; pub mod instruction_set;
pub mod json_to_instruction;
pub mod memory_manager; pub mod memory_manager;
pub mod send_recv; pub mod send_recv;
pub mod utility;
pub mod json_to_instruction;
pub mod tracing; pub mod tracing;
pub mod utility;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CoreInstructionsBuilder { pub struct CoreInstructionsBuilder {
core_instructions : Vec<CoreInstruction> core_instructions: Vec<CoreInstructions>,
} }
impl CoreInstructionsBuilder { impl CoreInstructionsBuilder {
pub fn new(size: usize) -> Self { pub fn new(size: usize) -> Self {
let mut core_instructions = Vec::with_capacity(size); let mut core_instructions = Vec::with_capacity(size);
for _ in 0..=size { for _ in 0..=size {
core_instructions.push(CoreInstruction::empty()); core_instructions.push(CoreInstructions::empty());
} }
Self { core_instructions } Self { core_instructions }
} }
pub fn build(self) -> Vec<CoreInstruction> { pub fn build(self) -> Vec<CoreInstructions> {
self.core_instructions self.core_instructions
} }
pub fn set_core(&mut self, core: impl TryToUsize, core_instruction: Instructions) -> &mut Self { pub fn set_core(&mut self, core: impl TryToUsize, core_instruction: Instructions) -> &mut Self {
self.core_instructions[core.try_into().expect("Set core with not valid size")] = core_instruction.into(); self.core_instructions[core.try_into().expect("Set core with not valid size")] =
core_instruction.into();
self self
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CoreInstruction { pub struct CoreInstructions {
instructions: Instructions, instructions: Instructions,
program_counter: usize, program_counter: usize,
} }
impl CoreInstruction { impl CoreInstructions {
fn new(instructions: Instructions, program_counter: usize) -> Self { fn new(instructions: Instructions, program_counter: usize) -> Self {
Self { Self {
instructions, instructions,
@@ -53,13 +57,16 @@ impl CoreInstruction {
} }
fn empty() -> Self { fn empty() -> Self {
Self { instructions: Vec::new(), program_counter: 0 } Self {
instructions: Vec::new(),
program_counter: 0,
}
} }
} }
impl From<Instructions> for CoreInstruction { impl From<Instructions> for CoreInstructions {
fn from(value: Instructions) -> Self { fn from(value: Instructions) -> Self {
CoreInstruction { CoreInstructions {
instructions: value, instructions: value,
program_counter: 0, program_counter: 0,
} }
@@ -69,39 +76,62 @@ impl From<Instructions> for CoreInstruction {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Executable<'a> { pub struct Executable<'a> {
cpu: CPU<'a>, cpu: CPU<'a>,
core_instructions: Vec<CoreInstruction>, core_instructions: Vec<CoreInstructions>,
send_recv: SendRecv, send_recv: SendRecv,
} }
fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0;
let mut progress = 0;
for core_instruction in core_instructions.iter() {
tot_instructions += core_instruction.instructions.len();
progress += core_instruction.program_counter;
}
println!(
"Progress: {}% ({}/{}) ",
progress as f32 / tot_instructions as f32 * 100.0,
progress,
tot_instructions
);
}
impl<'a> Executable<'a> { impl<'a> Executable<'a> {
pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstruction>) -> Executable<'a> { pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstructions>) -> Executable<'a> {
let num_core = cpu.num_core(); let num_core = cpu.num_core();
let send_recv = SendRecv::new(num_core); let send_recv = SendRecv::new(num_core);
assert_eq!(num_core, core_instructions.len(), "Some core doesn't have is list of istruction (required even if empty)"); assert_eq!(
num_core,
core_instructions.len(),
"Some core doesn't have is list of istruction (required even if empty)"
);
Self { Self {
cpu, cpu,
core_instructions, core_instructions,
send_recv send_recv,
} }
} }
pub fn execute<'b>(&'b mut self) pub fn execute<'b>(&'b mut self)
where 'a : 'b where
'a: 'b,
{ {
let Self { let Self {
cpu, cpu,
core_instructions, core_instructions: cores_instructions,
send_recv send_recv,
} = self; } = self;
let mut cpu_progressed = 0; let mut cpu_progressed = 0;
let max_core = cpu.num_core(); let max_core = cpu.num_core();
let mut index_unit = 0; let mut cpu_index = 0;
let mut now = SystemTime::now();
while (cpu_progressed > -2) { while (cpu_progressed > -2) {
let mut core_result = InstructionStatus::Completed; let mut core_result = InstructionStatus::Completed;
while core_result.is_completed() && let Some(core_instruction) = core_instructions.get_mut(index_unit){ while core_result.is_completed()
&& let Some(core_instruction) = cores_instructions.get_mut(cpu_index)
{
core_result = InstructionStatus::NotExecuted; core_result = InstructionStatus::NotExecuted;
let CoreInstruction { let CoreInstructions {
instructions, instructions,
program_counter, program_counter,
} = core_instruction; } = core_instruction;
@@ -114,16 +144,31 @@ impl<'a> Executable<'a> {
cpu_progressed = 0; cpu_progressed = 0;
*program_counter += 1; *program_counter += 1;
} }
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
print_status(&cores_instructions);
now = SystemTime::now();
} }
if handle_send_recv(cpu, core_instructions, send_recv, core_result) { cpu_progressed = 0; } }
handle_wait_sync(cpu, core_instructions, core_result); handle_wait_sync(cpu, cores_instructions, core_result);
index_unit = if index_unit + 1 >= max_core { match handle_send_recv(cpu, cores_instructions, send_recv, core_result) {
(true, other_cpu_index) => {
cpu_progressed = 0;
cpu_index = other_cpu_index;
}
(false, 0) => {
cpu_index = if cpu_index + 1 >= cores_instructions.len() {
cpu_progressed -= 1; cpu_progressed -= 1;
0 0
} else { } else {
index_unit + 1 cpu_index + 1
}; };
} }
(false, other_cpu_index) => {
cpu_index = other_cpu_index;
}
}
}
print_status(cores_instructions);
} }
pub fn cpu(&self) -> &CPU<'a> { pub fn cpu(&self) -> &CPU<'a> {
@@ -145,13 +190,12 @@ impl<'a> Executable<'a> {
} }
} }
fn handle_wait_sync<'a, 'b, 'c >(cpu: &'b mut CPU<'a>, core_instructions: &'c mut [CoreInstruction], core_result: InstructionStatus) fn handle_wait_sync<'a, 'b, 'c>(
where 'a : 'b, cpu: &'b mut CPU<'a>,
'a : 'c core_instructions: &'c mut [CoreInstructions],
core_result: InstructionStatus,
) where
'a: 'b,
'a: 'c,
{ {
} }

View File

@@ -1,7 +1,7 @@
use anyhow::Context; use anyhow::Context;
use crate::{ use crate::{
CoreInstruction, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER, CoreInstructions, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
utility::add_offset_rd, utility::add_offset_rd,
}; };
@@ -43,14 +43,14 @@ impl SendRecv {
pub fn handle_send_recv<'a, 'b >( pub fn handle_send_recv<'a, 'b >(
cpu: &'b mut CPU<'a>, cpu: &'b mut CPU<'a>,
core_instructions: & mut [CoreInstruction], core_instructions: & mut [CoreInstructions],
send_recv: & mut SendRecv, send_recv: & mut SendRecv,
core_result: InstructionStatus, core_result: InstructionStatus,
) -> bool ) -> (bool, usize)
where 'a : 'b where 'a : 'b
{ {
let transfer_memory = |cpu: &'b mut CPU<'a>, let transfer_memory = |cpu: &'b mut CPU<'a>,
core_instructions: & mut [CoreInstruction], core_instructions: & mut [CoreInstructions],
sender: Option<SendRecvInfo>, sender: Option<SendRecvInfo>,
receiver: Option<SendRecvInfo>| { receiver: Option<SendRecvInfo>| {
if let Some(sender) = sender if let Some(sender) = sender
@@ -119,7 +119,7 @@ where 'a : 'b
send_recv.sending[sender] = None; send_recv.sending[sender] = None;
send_recv.receiving[receiver] = None; send_recv.receiving[receiver] = None;
} }
transfered (transfered, receiver)
} }
InstructionStatus::Reciving(instruction_data) => { InstructionStatus::Reciving(instruction_data) => {
let (core_idx, imm_core) = instruction_data.get_core_immcore(); let (core_idx, imm_core) = instruction_data.get_core_immcore();
@@ -148,8 +148,8 @@ where 'a : 'b
send_recv.sending[sender] = None; send_recv.sending[sender] = None;
send_recv.receiving[receiver] = None; send_recv.receiving[receiver] = None;
} }
transfered (transfered, sender)
} }
_ => false, _ => (false, 0),
} }
} }

View File

@@ -1,11 +1,10 @@
mod tracing_isa; mod tracing_isa;
mod disable; mod disable;
mod pretty_print; mod pretty_print;
#[cfg(feature = "tracing")]
use std::{fs::File, path::{ PathBuf}}; use std::{fs::File, path::{ PathBuf}};
use std::sync::{LazyLock, Mutex}; use std::sync::{LazyLock, Mutex};
use crate::Executable; use crate::Executable;
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]

View File

@@ -1,5 +1,12 @@
add_pim_library(OMPimCommon add_pim_library(OMPimCommon
PimCommon.cpp IR/AddressAnalysis.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
IR/WeightUtils.cpp
Support/DebugDump.cpp
Support/Diagnostics.cpp
Support/FileSystemUtils.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -0,0 +1,258 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
}
namespace {
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (mlir::isa<mlir::BlockArgument>(value))
return value;
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs + *rhs;
}
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs - *rhs;
}
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs * *rhs;
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return mlir::failure();
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!integerAttr)
return mlir::failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return mlir::failure();
offsets.push_back(*resolvedOffset);
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return mlir::failure();
sizes.push_back(*resolvedSize);
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return mlir::failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,43 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
/// byte offset after peeling aliases, casts, and contiguous subviews.
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
/// Records compile-time facts used when interpreting address arithmetic and
/// loop-carried aliases inside PIM regions.
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
/// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
} // namespace onnx_mlir

View File

@@ -0,0 +1,67 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
}
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,24 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
/// Returns true for ops in a `pim.core` body that only participate in static
/// address or index computation and therefore do not emit PIM instructions.
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
/// their bounds are known and invoking `callback` only on instruction-emitting
/// operations.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir

View File

@@ -0,0 +1,45 @@
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
if (!moduleOp)
return mlir::failure();
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return mlir::failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return mlir::failure();
}
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return mlir::failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
return mainGraphFunc;
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return mlir::failure();
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
/// Resolves the function the PIM pipeline should treat as its entry point.
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
/// non-external function if the module is otherwise unambiguous.
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir

View File

@@ -0,0 +1,89 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
llvm::SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
llvm::SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,22 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir

View File

@@ -0,0 +1,101 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(mlir::Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
mlir::Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
llvm::SmallPtrSet<mlir::Value, 8> visited;
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
if (!visited.insert(currentValue).second)
return true;
if (currentValue.use_empty())
return false;
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
if (isSpatialMvmVmmWeightUse(use))
return true;
mlir::Operation* user = use.getOwner();
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
return false;
});
};
return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
});
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,29 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/StringRef.h"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable
/// weight across later PIM lowering/codegen stages.
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
/// allowing later passes to preserve it as a dedicated weight-like object.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// Visits weight operands consumed by Pim core ops/core batches so downstream
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
} // namespace onnx_mlir

View File

@@ -1,546 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
if (auto result = dyn_cast<OpResult>(value))
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs + *rhs;
}
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs - *rhs;
}
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs * *rhs;
}
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return failure();
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
if (!integerAttr)
return failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return failure();
// Trace the loop carry back to its underlying memref, then if that memref is the
// loop's own iter-arg we know the base comes from the corresponding init arg
// (every iteration yields the same backing memory in the DPS sense).
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return failure();
offsets.push_back(*resolvedOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return failure();
sizes.push_back(*resolvedSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return failure();
}
}
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
bool isCoreStaticAddressOp(Operation* op) {
return isa<arith::ConstantOp,
arith::AddIOp,
arith::SubIOp,
arith::MulIOp,
arith::DivUIOp,
arith::RemUIOp,
arith::IndexCastOp,
memref::AllocOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp>(op);
}
LogicalResult walkPimCoreBlock(Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (Operation& op : block) {
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir

View File

@@ -7,82 +7,21 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir { namespace onnx_mlir {
struct ResolvedContiguousAddress { inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id";
mlir::Value base;
int64_t byteOffset = 0;
};
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -0,0 +1,27 @@
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
namespace onnx_mlir {
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
moduleOp.print(os, flags);
os.flush();
file.close();
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
#include <string>
namespace onnx_mlir {
/// Emits a MLIR snapshot under the current compiler output
/// directory for pass-level debugging.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
} // namespace onnx_mlir

View File

@@ -0,0 +1,41 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
return op->emitOpError() << "requires statically shaped " << valueDescription;
}
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks) {
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
if (supportedRanks.empty())
return diag;
diag << "; supported rank";
if (supportedRanks.size() != 1)
diag << 's';
diag << ' ';
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
return diag;
}
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
}
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
llvm::StringRef action,
llvm::StringRef path,
const std::error_code& errorCode) {
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
return mlir::failure();
}
} // namespace onnx_mlir::pim

View File

@@ -0,0 +1,38 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <system_error>
namespace onnx_mlir::pim {
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
/// accepted by the current lowering/codegen path.
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks);
/// Emits a consistent diagnostic for missing symbol/global references.
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
/// the relevant IR location.
mlir::LogicalResult
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
template <typename T>
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
return mlir::success(succeeded(value));
}
} // namespace onnx_mlir::pim

View File

@@ -0,0 +1,24 @@
#include <filesystem>
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,13 @@
#pragma once
#include <string>
namespace onnx_mlir {
/// Returns the directory that should hold PIM artifacts/debug dumps for the
/// current compiler invocation.
std::string getOutputDir();
void createDirectory(const std::string& directory);
} // namespace onnx_mlir

View File

@@ -1,10 +1,15 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@@ -16,7 +21,7 @@
#include <utility> #include <utility>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -33,6 +38,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
void PimMemory::allocateGatheredMemory() {
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
for (auto& [memEntry, value] : memEntries)
allocateMemoryForValue(value, memEntry);
}
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress; memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size; firstAvailableAddress += memEntry.size;
@@ -44,35 +55,49 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
} }
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// More than one SSA value per single global constant: SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
// Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others SmallVector<mlir::Value> args;
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
for (mlir::Value arg : funcOp.getArguments()){
gatherMemEntry(arg);
args.push_back(arg);
}
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!hasWeightAlways(getGlobalOp)) { if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
auto iter = globalConstants.find(globalMemrefOp); if (globalMemrefOp.getName().starts_with("arg")){
if (iter == globalConstants.end()) StringRef indexStr = globalMemrefOp.getName().substr(4);
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); int index = 0;
else { llvm::to_integer(indexStr,index, 10);
MemEntry memEntry = *iter->second; globalAliases.push_back({getGlobalOp.getResult(), args[index]});
globalMemEntriesMap[getGlobalOp] = memEntry;
} }
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
if (inserted)
gatherMemEntry(getGlobalOp.getResult());
else
globalAliases.push_back({getGlobalOp.getResult(), iter->second});
} }
}); });
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
allocateCore(funcOp); funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult());
});
allocateGatheredMemory();
for (auto [alias, original] : globalAliases)
globalMemEntriesMap[alias] = getMemEntry(original);
} }
void PimMemory::allocateCore(Operation* op) { void PimMemory::allocateCore(Operation* op) {
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); allocateGatheredMemory();
for (auto& [memEntry, value] : memEntries)
allocateMemoryForValue(value, memEntry);
} }
MemEntry PimMemory::getMemEntry(mlir::Value value) const { MemEntry PimMemory::getMemEntry(mlir::Value value) const {
@@ -122,6 +147,12 @@ json::Object PimCodeGen::createEmptyOffset() {
return offset; return offset;
} }
size_t PimCodeGen::remapCoreId(size_t coreId) const {
auto it = emittedCoreIds.find(coreId);
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
return it->second;
}
static json::Object createRs1OnlyOffset() { static json::Object createRs1OnlyOffset() {
json::Object offset; json::Object offset;
offset["offset_select"] = 1; offset["offset_select"] = 1;
@@ -181,7 +212,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
json::Object json; json::Object json;
json["op"] = opName; json["op"] = opName;
json["rd"] = 0; json["rd"] = 0;
json["core"] = coreId; json["core"] = remapCoreId(coreId);
json["size"] = size; json["size"] = size;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
emitInstruction(std::move(json)); emitInstruction(std::move(json));
@@ -403,6 +434,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge); auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
@@ -465,6 +499,136 @@ std::string getMemorySizeAsString(size_t size) {
return std::to_string(size) + " Bytes"; return std::to_string(size) + " Bytes";
} }
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName);
assert(coreIdsAttr && "pim.core_batch requires core_id array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front()) {
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
}
return coreLikeOps;
}
static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
OpBuilder builder(coreBatchOp);
builder.setInsertionPointAfter(coreBatchOp);
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<mlir::Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(builder,
coreBatchOp.getLoc(),
ValueRange(laneWeights),
builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op)) {
pim::PimHaltOp::create(builder, op.getLoc());
continue;
}
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveOp::create(builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
hostSource = memcpBatchOp.getHostSource();
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
hostSource,
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
return scalarCore;
}
static void aliasMaterializedHostGlobals(
ModuleOp moduleOp, func::FuncOp funcOp, pim::PimCoreOp coreOp, PimAcceleratorMemory& memory) {
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
return;
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!targetGlobal)
return;
mlir::Value aliasedValue;
funcOp.walk([&](memref::GetGlobalOp candidate) {
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
return;
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
aliasedValue = candidate.getResult();
});
if (aliasedValue)
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
});
}
/// Write global constant data into a binary memory image at their allocated addresses. /// Write global constant data into a binary memory image at their allocated addresses.
static OnnxMlirCompilerErrorCodes static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
@@ -478,12 +642,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp)) if (hasWeightAlways(getGlobalOp))
return; return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) if (!globalOp)
return; return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue(); auto initialValue = globalOp.getInitialValue();
if (!initialValue) if (!initialValue)
return; return;
@@ -556,6 +723,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op)) else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else { else {
op.emitError("Unsupported codegen for this operation"); op.emitError("Unsupported codegen for this operation");
op.dump(); op.dump();
@@ -645,7 +814,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
return CompilerSuccess; return CompilerSuccess;
} }
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>(); ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights"; auto coreWeightsDirPath = outputDirPath + "/weights";
@@ -654,11 +823,30 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
size_t indexFileName = 0; size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue(); int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName; llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName; llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) { SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
for (Operation* op : coreLikeOps) {
SmallVector<pim::PimCoreOp> scalarCores;
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
scalarCores.push_back(coreOp);
}
else {
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
}
for (pim::PimCoreOp coreOp : scalarCores) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>(); auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) { if (!getGlobalOp) {
@@ -687,7 +875,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
if (mapGlobalOpToFileName.contains(globalOp)) { if (mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp]; auto& fileName = mapGlobalOpToFileName[globalOp];
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName}; std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
mapCoreWeightToFileName[coreOp].insert(weightToFile); mapCoreWeightToFileName[coreId].insert(weightToFile);
continue; continue;
} }
@@ -726,22 +914,28 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
weightFileStream.close(); weightFileStream.close();
mapGlobalOpToFileName.insert({globalOp, newFileName}); mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreOp].insert({weight, newFileName}); mapCoreWeightToFileName[coreId].insert({weight, newFileName});
} }
} }
for (pim::PimCoreOp coreOp : scalarCores)
if (coreOp.getOperation() != op)
coreOp.erase();
}
return mapCoreWeightToFileName; return mapCoreWeightToFileName;
} }
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses). /// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp, static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory, PimAcceleratorMemory& memory,
size_t coreCount, size_t maxCoreId,
json::Object xbarsPerArrayGroup, json::Object xbarsPerArrayGroup,
StringRef outputDirPath) { StringRef outputDirPath) {
json::Object configJson; json::Object configJson;
// +1 because pimsim-nn also considers the host as a core // pimsim-nn indexes cores directly by their numeric core ID, with the host
configJson["core_cnt"] = coreCount + 1; // occupying core 0.
configJson["core_cnt"] = maxCoreId + 1;
// TODO: Should this be based on the floating point type used in the model? // TODO: Should this be based on the floating point type used in the model?
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision // The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
@@ -815,14 +1009,47 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
// For each core, specify the number of crossbar per array group. // For each core, specify the number of crossbar per array group.
// This implementation always assigns one crossbar per group. // This implementation always assigns one crossbar per group.
json::Object xbarsPerArrayGroup; json::Object xbarsPerArrayGroup;
size_t coreCount = 0; size_t maxCoreId = 0;
// Create Weight Folder // Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) { SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
auto coreId = coreOp.getCoreId(); llvm::DenseMap<size_t, size_t> emittedCoreIds;
coreCount++; size_t nextEmittedCoreId = 1;
for (Operation* op : coreLikeOps) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
if (!emittedCoreIds.contains(originalCoreId))
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
if (!emittedCoreIds.contains(originalCoreId))
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
}
}
for (Operation* op : coreLikeOps) {
SmallVector<pim::PimCoreOp> scalarCores;
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
scalarCores.push_back(coreOp);
}
else {
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
}
for (pim::PimCoreOp coreOp : scalarCores) {
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
size_t coreId = emittedCoreIds.lookup(originalCoreId);
maxCoreId = std::max(maxCoreId, coreId);
std::error_code errorCode; std::error_code errorCode;
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json"; auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
@@ -833,7 +1060,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
} }
coreFileStream << '['; coreFileStream << '[';
PimCodeGen coreCodeGen(memory, coreFileStream); PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
@@ -841,29 +1069,32 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
return CompilerFailure; return CompilerFailure;
assert(processedOperations > 0); assert(processedOperations > 0);
// Remove trailing comma, close JSON array
coreFileStream.seek(coreFileStream.tell() - 1); coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']'; coreFileStream << ']';
coreFileStream.close(); coreFileStream.close();
// Write crossbar weights for this core
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp]; auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
json::Array xbarsPerGroup; json::Array xbarsPerGroup;
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
xbarsPerGroup.push_back(index); xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight]; auto& fileName = mapWeightToFile[weight];
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message() << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:"
<< '\n'; << error.message() << '\n';
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
} }
@@ -871,5 +1102,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
} }
return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath); for (pim::PimCoreOp coreOp : scalarCores)
if (coreOp.getOperation() != op)
coreOp.erase();
}
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
} }

View File

@@ -1,5 +1,6 @@
#pragma once #pragma once
#include "llvm/ADT/DenseMap.h"
#include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
@@ -24,6 +25,7 @@ class PimMemory {
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value); MemEntry* gatherMemEntry(mlir::Value value);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry); void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
public: public:
@@ -57,10 +59,12 @@ public:
class PimCodeGen { class PimCodeGen {
PimAcceleratorMemory& memory; PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream; llvm::raw_fd_ostream& coreFileStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge); return memory.getValueAddress(value, knowledge);
} }
size_t remapCoreId(size_t coreId) const;
static llvm::json::Object createEmptyOffset(); static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const; void emitInstruction(llvm::json::Object instruction) const;
@@ -82,8 +86,10 @@ class PimCodeGen {
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const; void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public: public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson) PimCodeGen(PimAcceleratorMemory& memory,
: memory(memory), coreFileStream(coreJson) {} llvm::raw_fd_ostream& coreJson,
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
@@ -105,6 +111,7 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
}; };

View File

@@ -41,12 +41,18 @@ llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2)); crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count", llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."), llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
llvm::cl::init(-1)); llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(4000));
llvm::cl::opt<bool> llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error", ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"), llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),

View File

@@ -29,6 +29,7 @@ extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<size_t> crossbarSize; extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore; extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount; extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
// This option, by default set to false, will ignore an error when resolving a // This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the // specific tiles of the operands of a concat. This specific case is when the

View File

@@ -18,7 +18,9 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
Common.cpp Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -1,279 +0,0 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <type_traits>
#include <utility>
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool isWeightLikeComputeOperand(mlir::Value value) {
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
}; // namespace onnx_mlir

View File

@@ -0,0 +1,8 @@
#pragma once
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ComputeRegionBuilder.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"

View File

@@ -0,0 +1,39 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,153 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir

View File

@@ -1,24 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include <cassert> #include "ShapeTilingUtils.hpp"
#include <optional>
#include <utility>
#include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -107,31 +93,4 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
return tensor::SplatOp::create(rewriter, loc, type, elementValue); return tensor::SplatOp::create(rewriter, loc, type, elementValue);
} }
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) { } // namespace onnx_mlir
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
}; // namespace onnx_mlir

View File

@@ -0,0 +1,143 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
}
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir

View File

@@ -0,0 +1,114 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "WeightMaterialization.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
/// Returns true when a matrix-valued compute operand is ultimately backed by a
/// weight-marked constant/view chain and can be promoted into weights.
bool isWeightLikeComputeOperand(mlir::Value value);
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
/// compute body while reusing already-materialized intermediate values.
llvm::FailureOr<mlir::Value>
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
} // namespace onnx_mlir

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -8,7 +9,6 @@
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@@ -18,14 +18,12 @@
#include <iterator> #include <iterator>
#include <utility> #include <utility>
#include "Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -33,8 +31,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace { namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
@@ -51,13 +47,49 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private: private:
void annotateWeightsConstants(func::FuncOp funcOp) const; void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp); LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp); LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
}; };
} // namespace } // namespace
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<spatial::SpatComputeBatch> batchOps;
funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
for (auto batchOp : batchOps) {
if (batchOp.getLaneCount() != 1)
continue;
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock = rewriter.createBlock(
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp.replaceAllUsesWith(computeOp.getResults());
rewriter.eraseOp(batchOp);
}
}
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
@@ -87,8 +119,7 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect, tensor::TensorDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>( target.addIllegalOp<ONNXMatMulOp>();
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXAddOp>(); target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>(); target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>(); target.addIllegalOp<ONNXMulOp>();
@@ -129,11 +160,13 @@ void ONNXToSpatialPass::runOnOperation() {
return; return;
} }
foldSingleLaneComputeBatches(*entryFunc);
// Count the number of compute ops and check they do not exceed the core count // Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (auto& op : entryFunc->getFunctionBody().front().getOperations()) for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op)) if (isa<spatial::SpatCompute>(op))
computeOpsCount++; computeOpsCount++;
if (computeOpsCount > coresCount) { if (computeOpsCount > coresCount) {
@@ -149,15 +182,17 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
signalPassFailure();
return;
}
if (failed(promoteConstantInputsToWeights(*entryFunc))) { if (failed(promoteConstantInputsToWeights(*entryFunc))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
mergeTriviallyConnectedComputes(*entryFunc);
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "spatial0"); dumpModule(moduleOp, "spatial0");
} }
@@ -167,19 +202,36 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) { if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp); Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) { auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB); rewriter.setInsertionPointToEnd(BB);
IRMapping mapper; IRMapping mapper;
mapper.map(source, BB->getArgument(0)); mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper); auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0)); spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute); inst->replaceAllUsesWith(newCompute->getResults());
inst->erase(); inst->erase();
return true; return true;
} }
return false;
}
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
auto source = toRemoveOp.getSource();
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
} }
return false; return false;
} }
@@ -188,9 +240,32 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) { if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs(); auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of( if (llvm::any_of(sources,
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) { [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources); auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = spatial::SpatConcatOp::create(rewriter,
loc,
toRemoveOp.getType(),
rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
ValueRange(BB->getArguments()));
spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes; SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc; SmallVector<Location> sourceLoc;
for (auto source : sources) { for (auto source : sources) {
@@ -204,82 +279,115 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg); mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper); auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0)); spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute); inst->replaceAllUsesWith(newCompute->getResults());
inst->erase(); inst->erase();
return true; return true;
} }
}
return false; return false;
} }
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) { static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
if (auto mapped = mapper.lookupOrNull(value)) if (op == nullptr)
return cast<Value>(mapped); return false;
Operation* definingOp = value.getDefiningOp(); Operation* source = nullptr;
if (!definingOp) do {
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) { if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType()); return false;
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
} }
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp)) auto tmpSource = extractSliceOp.getSource();
return failure(); auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
IRMapping localMapper; op = definingOp;
for (Value operand : definingOp->getOperands()) { else
if (auto mapped = mapper.lookupOrNull(operand)) { return false;
localMapper.map(operand, cast<Value>(mapped));
continue;
} }
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
if (isWeightLikeComputeOperand(operand)) { auto tmpSource = extractRowsOp.getInput();
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper); auto definingOp = tmpSource.getDefiningOp();
if (failed(clonedOperand)) if (definingOp)
return failure(); op = definingOp;
localMapper.map(operand, *clonedOperand); else
continue; return false;
} }
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
localMapper.map(operand, operand); auto tmpSource = expandShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
} }
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
Operation* clonedOp = rewriter.clone(*definingOp, localMapper); auto tmpSource = transposeOp.getData();
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults())) auto definingOp = tmpSource.getDefiningOp();
mapper.map(oldResult, newResult); if (definingOp)
op = definingOp;
auto mapped = mapper.lookupOrNull(value); else
if (!mapped) return false;
}
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
auto tmpSource = collapseShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
source = constantOp;
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
bool res = false;
for (auto operand : concatOp.getOperands()) {
res |= hasWeightAlways(operand.getDefiningOp());
if (res)
return res;
}
return res;
}
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
bool res = false;
for (auto operand : concatOp.getOperands()) {
res |= hasWeightAlways(operand.getDefiningOp());
if (res)
return res;
}
return res;
}
else {
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
return failure(); return failure();
return cast<Value>(mapped); }
}
while (source == nullptr);
return hasWeightAlways(source);
} }
// TODO what we want to keep in global? // TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
bool keep = true; bool keep = true;
while (keep) { while (keep) {
keep = false; keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>( if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); }); instruction)
|| isa<func::ReturnOp>(instruction))
continue;
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
if (failed(weightBacked))
return failure();
if (*weightBacked)
continue;
keep |= encapsulateSlice(rewriter, loc, &instruction);
keep |= encapsulator<tensor::ExpandShapeOp>( keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); }); rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
@@ -293,108 +401,19 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
keep |= encapsulateConcat(rewriter, loc, &instruction); keep |= encapsulateConcat(rewriter, loc, &instruction);
} }
} }
} return success();
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
if (compute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
if (user && user.getInputs().size() == 1)
trivialComputes.push_back(compute);
}
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto founded = llvm::find(newCompute.getWeights(), weight);
if (founded == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *founded);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (newCompute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
if (user && user.getInputs().size() == 1)
trivialComputes.push_back(newCompute);
}
}
for (auto compute : toErase) {
compute.getResult(0).dropAllUses();
compute.erase();
}
} }
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) { funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight = if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp); markWeightAlways(constantOp);
}); });
} }
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) { LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>()); SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
for (auto compute : computes) { for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false); SmallVector<bool> promoteInput(compute.getInputs().size(), false);
@@ -430,7 +449,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
} }
auto newCompute = auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock = auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(

View File

@@ -7,11 +7,10 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,10 +146,11 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
} }
static Value createIm2colCompute(Value x, static Value createIm2colRowComputes(Value x,
RankedTensorType xType, RankedTensorType xType,
RankedTensorType im2colType, RankedTensorType im2colType,
RankedTensorType rowType, RankedTensorType im2colRowType,
RankedTensorType gemmInputRowsType,
int64_t batchSize, int64_t batchSize,
int64_t numChannelsIn, int64_t numChannelsIn,
int64_t xHeight, int64_t xHeight,
@@ -169,11 +169,14 @@ static Value createIm2colCompute(Value x,
int64_t patchSize, int64_t patchSize,
int64_t numPatches, int64_t numPatches,
int64_t numPatchesPerBatch, int64_t numPatchesPerBatch,
int64_t packFactor,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
auto elemType = xType.getElementType(); auto elemType = xType.getElementType();
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) { const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
auto im2colComputeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
Value paddedInput = xArg; Value paddedInput = xArg;
// Pad input with zeros if needed: // Pad input with zeros if needed:
@@ -240,7 +243,7 @@ static Value createIm2colCompute(Value x,
Value row = tensor::CollapseShapeOp::create(rewriter, Value row = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
rowType, im2colRowType,
patch, patch,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0}, {0},
@@ -256,28 +259,13 @@ static Value createIm2colCompute(Value x,
rewriter.setInsertionPointAfter(im2colLoop); rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0); Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
return im2colComputeOp.getResult(0);
}
static Value createPackedIm2colRows(Value im2col, Value gemmInputRows = im2col;
RankedTensorType im2colType, if (packFactor != 1) {
Type elemType,
int64_t numPatches,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return im2col;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor; const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) { Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
groupedType, groupedType,
@@ -286,7 +274,7 @@ static Value createPackedIm2colRows(Value im2col,
{0, 1}, {0, 1},
{2} {2}
}); });
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter, gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
packedType, packedType,
groupedIm2col, groupedIm2col,
@@ -294,31 +282,39 @@ static Value createPackedIm2colRows(Value im2col,
{0}, {0},
{1, 2} {1, 2}
}); });
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
});
return packedComputeOp.getResult(0);
} }
static Value createUnpackedOutput(Value packedOutput, spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
});
return im2colComputeOp.getResult(0);
}
static Value createCollectedConvOutput(ValueRange gemmRows,
Type convType,
RankedTensorType gemmOutType, RankedTensorType gemmOutType,
RankedTensorType nhwcType,
RankedTensorType outType, RankedTensorType outType,
int64_t numPatches, int64_t numPatches,
int64_t numChannelsOut, int64_t numChannelsOut,
int64_t packFactor, int64_t packFactor,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
if (packFactor == 1)
return packedOutput;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor; const int64_t paddedNumPatches = packedNumRows * packFactor;
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) { Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
expandedType, expandedType,
packedOutputArg, packedOutput,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0}, {0},
{1, 2} {1, 2}
@@ -332,30 +328,15 @@ static Value createUnpackedOutput(Value packedOutput,
{2} {2}
}); });
Value unpackedOutput = paddedOutput; gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) { if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)}; SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput = gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
} }
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
});
return unpackComputeOp.getResult(0);
} }
static Value createCollectedConvOutput(Value gemmOut,
Type convType,
RankedTensorType nhwcType,
RankedTensorType outType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto collectComputeOp =
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
Value gemmOutArg = gemmOutArgs.front();
// Restore to NCHW layout: // Restore to NCHW layout:
// [numPatches, numChannelsOut] // [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut] // -> [1, outHeight, outWidth, numChannelsOut]
@@ -363,7 +344,7 @@ static Value createCollectedConvOutput(Value gemmOut,
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
nhwcType, nhwcType,
gemmOutArg, gemmOut,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0, 1, 2}, {0, 1, 2},
{3} {3}
@@ -388,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
auto wType = cast<RankedTensorType>(w.getType()); auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType()); auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); if (!xType.hasStaticShape()) {
assert("Only support 2D convolution" && xType.getRank() == 4); pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
// We need to understand what is group }
assert("Only support group=1" && convOp.getGroup() == 1); if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0); const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1); const int64_t numChannelsIn = xType.getDimSize(1);
@@ -409,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const auto dilationsAttr = convOp.getDilations(); const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads(); const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -449,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
padWidthBegin = totalPadW - padWidthEnd; padWidthBegin = totalPadW - padWidthEnd;
} }
} }
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0 // "NOTSET" or "VALID" -> all pads stay 0
} }
@@ -487,11 +508,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// Pass bias through directly; Gemm handles rank-1 C canonicalization. // Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp()); bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix; Value biasMatrix;
DenseElementsAttr biasDenseAttr; DenseElementsAttr biasDenseAttr;
if (hasB) { if (hasB) {
gemmC = b; gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b); biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc); biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
} }
@@ -500,10 +521,26 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t effectiveMaxParallelPixels = const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1; (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
Value im2col = createIm2colCompute(x, // Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
//
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowsType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value gemmInputRows = createIm2colRowComputes(x,
xType, xType,
im2colType, im2colType,
rowType, rowType,
gemmInputRowsType,
batchSize, batchSize,
numChannelsIn, numChannelsIn,
xHeight, xHeight,
@@ -522,44 +559,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
patchSize, patchSize,
numPatches, numPatches,
numPatchesPerBatch, numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter, rewriter,
loc); loc);
Value gemmOut; Value gemmB = buildPackedWeight(wDenseAttr,
if (effectiveMaxParallelPixels == 1) {
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
gemmOut = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2col,
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value packedA = createPackedIm2colRows(
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
Value packedB = buildPackedWeight(wDenseAttr,
wTrans, wTrans,
wType, wType,
numChannelsIn, numChannelsIn,
@@ -570,24 +574,32 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
effectiveMaxParallelPixels, effectiveMaxParallelPixels,
rewriter, rewriter,
loc); loc);
Value packedC = buildPackedBias( Value gemmC = buildPackedBias(
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedOut = ONNXGemmOp::create(rewriter,
Value gemmRows = ONNXGemmOp::create(rewriter,
loc, loc,
packedOutType, gemmOutputRowsType,
packedA, gemmInputRows,
packedB, gemmB,
packedC, gemmC,
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false)) rewriter.getBoolAttr(false))
.getY(); .getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
}
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc)); rewriter.replaceOp(convOp,
createCollectedConvOutput(ValueRange {gemmRows},
convOp.getType(),
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc));
return success(); return success();
} }

View File

@@ -5,7 +5,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,13 +16,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
strides[i] = strides[i + 1] * shape[i + 1];
return strides;
}
static DenseElementsAttr getDenseConstantAttr(Value value) { static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue()); return dyn_cast<DenseElementsAttr>(constantOp.getValue());

View File

@@ -1,16 +1,16 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -65,6 +65,66 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
ConversionPatternRewriter& rewriter) const override; ConversionPatternRewriter& rewriter) const override;
}; };
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
RankedTensorType matrixType,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numRows = matrixType.getDimSize(0);
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
auto buildRowSlices = [&](Value matrixArg) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
};
auto cloneBatchInputChainIntoSliceCompute =
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
Value transformedMatrix = input;
if (!chainOps.empty()) {
IRMapping mapper;
mapper.map(rootValue, input);
for (Operation* chainOp : chainOps)
rewriter.clone(*chainOp, mapper);
transformedMatrix = cast<Value>(mapper.lookup(matrix));
}
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
});
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
return rowSlices;
};
SmallVector<Operation*> chainOps;
Value rootValue = matrix;
while (Operation* definingOp = rootValue.getDefiningOp()) {
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
}
if (definingOp->getNumOperands() != 1)
break;
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
break;
chainOps.push_back(definingOp);
rootValue = definingOp->getOperand(0);
}
return buildRowSlices(matrix);
}
} // namespace } // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
@@ -75,13 +135,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB(); Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC(); Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA()); if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp()); bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType()); auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType()); auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape()); if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0); const int64_t numOutRows = aType.getDimSize(0);
@@ -114,7 +184,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
}); });
cType = expandedType; cType = expandedType;
} }
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
cHasNumOutRows = cType.getDimSize(0) == numOutRows; cHasNumOutRows = cType.getDimSize(0) == numOutRows;
} }
@@ -138,8 +215,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult(); cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
} }
else else if (!isVectorShape(getTensorShape(c))) {
assert("C should be a vector" && isVectorShape(getTensorShape(c))); gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
return failure();
}
} }
auto gemvOp = ONNXGemmOp::create(rewriter, auto gemvOp = ONNXGemmOp::create(rewriter,
@@ -156,8 +235,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) { auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs); spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
}); });
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -198,11 +276,28 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
}); });
cType = expandedType; cType = expandedType;
} }
assert("Only support rank 2 tensor for C" && cType.getRank() == 2); if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
} }
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() if (!aType.hasStaticShape()) {
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv // Not a gemv
@@ -281,19 +376,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute( auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back( vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
}); });
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp.getResult(0)); partialResults.push_back(computeOp->getResult(0));
} }
if (hasC) { if (hasC) {
@@ -313,8 +414,129 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto concatComputeOp = auto concatComputeOp =
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) { createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs); spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult()); });
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
if (numOutRows <= 1)
return failure();
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType());
}
(void) bType;
Value sharedBias;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
auto cType = cast<RankedTensorType>(c.getType());
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
cType = cast<RankedTensorType>(c.getType());
}
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
return failure();
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
sharedBias = c;
}
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange(resultTypes),
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
ValueRange(weights),
ValueRange(aSlices));
Block* body = rewriter.createBlock(
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
rewriter.setInsertionPointAfter(batchOp);
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
}); });
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -322,6 +544,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
} }
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
patterns.insert<GemmToManyGemv>(ctx); patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx); patterns.insert<GemvToSpatialCompute>(ctx);
} }

View File

@@ -2,9 +2,10 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -14,7 +15,108 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> { static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2)
return value;
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
return tensor::CollapseShapeOp::create(rewriter,
loc,
matrixType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static bool isConstantLikeOperand(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
}
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -24,80 +126,113 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape()) || !outType.hasStaticShape())
return failure(); return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3) if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure(); return failure();
const int64_t batch = rhsType.getDimSize(0); const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
const int64_t k = rhsType.getDimSize(1); const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
const int64_t n = rhsType.getDimSize(2); const int64_t batch = std::max(lhsBatch, rhsBatch);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|| outType.getDimSize(2) != n)
return failure(); return failure();
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
if (k != rhsK)
return failure();
if (outType.getRank() == 2) {
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
return failure();
}
else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
return failure();
}
Location loc = matmulOp.getLoc(); Location loc = matmulOp.getLoc();
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType()); bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhsTransposed = Value lhs = matmulOp.getA();
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0})); Value rhs = matmulOp.getB();
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
int64_t gemmK = k;
int64_t gemmN = n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
SmallVector<Value> gemmRows; if (outType.getRank() == 2) {
gemmRows.reserve(batch * n); Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) { Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
for (int64_t colIdx = 0; colIdx < n; colIdx++) { Value gemmResult = ONNXGemmOp::create(rewriter,
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
rhsRowType, gemmType,
rhsSlice, lhsMatrix,
SmallVector<ReassociationIndices> { rhsMatrix,
{0},
{1, 2}
});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none, none,
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false)); rewriter.getBoolAttr(false))
gemmRows.push_back(gemmOp.getY()); .getY();
} if (useTransposedForm)
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
rewriter.replaceOp(matmulOp, gemmResult);
return success();
} }
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) { SmallVector<Value> batchResults;
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs); batchResults.reserve(batch);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
}); Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmOut = concatComputeOp.getResult(0); Value gemmResult = ONNXGemmOp::create(rewriter,
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
gemmExpandedType, gemmType,
gemmOut, lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm)
gemmResult = ONNXTransposeOp::create(
rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
gemmResult,
rewriter.getI64ArrayAttr({1, 0}));
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
gemmResult,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0, 1}, {0, 1},
{2} {2}
}); }));
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1})); }
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, result);
return success(); return success();
} }
@@ -106,7 +241,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace } // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx); patterns.insert<MatMulToGemm>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -5,7 +5,7 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -100,8 +100,7 @@ static Value buildReduceMeanKeepdims(Value input,
for (Value slice : slices) for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc)); reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return reducedSlices.size() == 1 ? reducedSlices.front() return createSpatConcat(rewriter, loc, axis, reducedSlices);
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
} }
static Value squeezeReducedAxes(Value keepdimsValue, static Value squeezeReducedAxes(Value keepdimsValue,

View File

@@ -6,13 +6,12 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include <optional> #include <optional>
#include <type_traits> #include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -31,11 +30,14 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
} }
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) { template <typename PoolOp>
assert(!values.empty() && "Expected at least one value to concatenate."); static FailureOr<Value>
if (values.size() == 1) concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
return values.front(); if (values.empty()) {
return tensor::ConcatOp::create(rewriter, loc, axis, values); poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
return failure();
}
return createSpatConcat(rewriter, loc, axis, values);
} }
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
@@ -53,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
} }
template <typename ReduceOp> template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) { static FailureOr<Value>
assert(!windowValues.empty() && "Expected at least one pool window value."); reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
if (windowValues.empty()) {
op->emitOpError("pool window resolved to zero valid elements");
return failure();
}
Value reduced = windowValues.front(); Value reduced = windowValues.front();
for (Value value : windowValues.drop_front()) for (Value value : windowValues.drop_front())
@@ -62,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
return reduced; return reduced;
} }
static Value static FailureOr<Value>
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) { scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive."); if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive");
return failure();
}
if (divisor == 1) if (divisor == 1)
return reducedWindow; return reducedWindow;
@@ -72,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
double scale = 1.0 / static_cast<double>(divisor); double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale)); auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr); Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor); return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
} }
template <typename PoolOp> template <typename PoolOp>
@@ -211,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (windowValues.empty()) if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements."); return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues); auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
if (failed(reducedWindow))
return failure();
Value reducedWindowValue = *reducedWindow;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) { if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1; const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor = const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size()); countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor); auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
if (failed(scaledWindow))
return failure();
reducedWindowValue = *scaledWindow;
} }
outputChannelTiles.push_back(reducedWindow); outputChannelTiles.push_back(reducedWindowValue);
} }
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles)); auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
if (failed(rowPixel))
return failure();
rowPixels.push_back(*rowPixel);
} }
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels)); auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
if (failed(row))
return failure();
rows.push_back(*row);
} }
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows)); auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
if (failed(batchResult))
return failure();
batchResults.push_back(*batchResult);
} }
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults); auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput); if (failed(pooledOutput))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))

View File

@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -47,8 +47,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
for (Value slice : slices) for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return rebuiltSlices.size() == 1 ? rebuiltSlices.front() return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
} }
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> { struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {

View File

@@ -1,7 +1,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -17,7 +18,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs(); auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis(); int64_t axis = adaptor.getAxis();
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs); rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success(); return success();
} }

View File

@@ -5,7 +5,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
} }
if (slices.empty()) if (slices.empty())
return {}; return {};
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); return createSpatConcat(rewriter, loc, axis, slices);
} }
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure(); return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
} }
result = rows.size() == 1 result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
? rows.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
} }
else { else {
return failure(); return failure();

View File

@@ -5,7 +5,7 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
} }
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); return createSpatConcat(rewriter, loc, axis, slices);
} }
struct Resize : OpConversionPattern<ONNXResizeOp> { struct Resize : OpConversionPattern<ONNXResizeOp> {

View File

@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -23,7 +23,10 @@ static Value extractSliceAt(
sizes.push_back(rewriter.getIndexAttr(dim)); sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset); offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size); sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); SmallVector<int64_t> resultShape(inputType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
} }
struct Split : OpConversionPattern<ONNXSplitOp> { struct Split : OpConversionPattern<ONNXSplitOp> {
@@ -49,12 +52,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return failure(); return failure();
int64_t sliceSize = resultType.getShape()[axis]; int64_t sliceSize = resultType.getShape()[axis];
auto computeOp = outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
});
outputs.push_back(computeOp.getResult(0));
offset += sliceSize; offset += sliceSize;
} }

View File

@@ -42,15 +42,15 @@ private:
raw_ostream& os; raw_ostream& os;
/** /**
* Draws the subgraph for a given spatial::SpatWeightedCompute, including: * Draws the subgraph for a given spatial::SpatCompute, including:
* 1. Input nodes (block arguments) * 1. Input nodes (block arguments)
* 2. Operations * 2. Operations
* 3. Edges between yield (output) and its users * 3. Edges between yield (output) and its users
* *
* @param op The spatial::SpatWeightedCompute to draw the subgraph for. * @param op The spatial::SpatCompute to draw the subgraph for.
* @param computeNum The number of the compute operation. * @param computeNum The number of the compute operation.
*/ */
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) { void drawComputeOpSubgraph(spatial::SpatCompute op, size_t computeNum) {
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n" os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
<< "\t\tstyle=filled;\n" << "\t\tstyle=filled;\n"
<< "\t\tcolor=lightblue;\n"; << "\t\tcolor=lightblue;\n";
@@ -217,7 +217,7 @@ void SpatialToGraphvizPass::runOnOperation() {
// 1. Print their subgraph // 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs // 2. Print the edges from its inputs to its outputs
for (Operation& op : func.getOps()) { for (Operation& op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++); drawComputeOpSubgraph(computeOp, computeNum++);
} }
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) { else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {

View File

@@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp SpatialToPimPass.cpp
Common.cpp Common.cpp
Patterns.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -7,23 +7,12 @@
#include <cstddef> #include <cstddef>
#include "Common.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
assert(attr && "required precomputed channel attr is missing");
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
}
} // namespace
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) { size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/* /*
EXAMPLE RUN: EXAMPLE RUN:
@@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())))); return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
} }
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
}
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
}
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
}
Operation* getEarliestUserWithinBlock(mlir::Value value) { Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers(); auto users = value.getUsers();
@@ -127,6 +85,16 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; }); return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
} }
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
for (Operation* user : value.getUsers()) {
if (user->getBlock() != operation->getBlock())
return true;
if (operation->isBeforeInBlock(user))
return true;
}
return false;
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) { mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1); assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0); mlir::Value result = operation->getResult(0);
@@ -134,8 +102,9 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter,
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType)); assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation); SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands = auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; }); return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
});
auto bestOperand = validOperands.begin(); auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end()) if (bestOperand != validOperands.end())

View File

@@ -2,16 +2,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
/** /**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and * \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input. * its static tensor input.
@@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
mlir::Value createPimReceiveFromSpatialChannel(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
template <class T> template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) { size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end()); return std::distance(range.begin(), range.end());

View File

@@ -0,0 +1,385 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
Location loc = extractSliceOp.getLoc();
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
return failure();
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
if (spatCompute.getInputs().empty())
return failure();
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
return failure();
}
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
return failure();
}
}
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty())
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty())
continue;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
else {
{
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
rewriter.eraseOp(extractSliceOp);
return success();
}
};
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
static int i = 0;
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
std::string argName = "const_" + std::to_string(i++);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
constantOp.getValueAttr(),
rewriter.getUnitAttr(),
{});
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
else {
if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
}
}
}
}
auto parent = constantOp->getParentOp();
rewriter.eraseOp(constantOp);
return success();
}
};
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
if (funcOp.getArguments().empty())
return failure();
if (llvm::all_of(funcOp.getArguments(),
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
return failure();
Location loc = funcOp.getLoc();
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
if (arg.getUses().empty())
continue;
rewriter.setInsertionPoint(funcOp.getOperation());
assert(isa<mlir::RankedTensorType>(arg.getType()));
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
mlir::MemRefType memRefType =
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
std::string argName = "arg_" + std::to_string(index);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(toTensor);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(toTensor);
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
else {
rewriter.setInsertionPoint(argUser);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(argUser);
argUses.set(toTensor);
rewriter.finalizeOpModification(argUser);
}
}
}
return success();
}
};
} // namespace
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,10 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
}

View File

@@ -9,17 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def HasSpatialChannelSourceCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
"spatial channel has precomputed source core id">;
def HasSpatialChannelTargetCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
"spatial channel has precomputed target core id">;
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
def onnxToPimTranspose : Pat< def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms), (ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms, (PimTransposeOp $data, $perms,
@@ -80,18 +69,4 @@ def spatToPimVSoftmax : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatChannelSendToPimSend : Pat<
(SpatChannelSendOp $channel, $input),
(PimSendOp $input,
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
[(HasSpatialChannelTargetCoreIdAttr $channel)]
>;
def spatChannelReceiveToPimReceive : Pat<
(SpatChannelReceiveOp:$srcOpRes $channel),
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
[(HasSpatialChannelSourceCoreIdAttr $channel)]
>;
#endif // SPATIAL_TO_PIM #endif // SPATIAL_TO_PIM

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@ def PimTensor :
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock]> { def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
let summary = "Execute a block on a PIM core"; let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
@@ -39,6 +39,22 @@ def PimCoreOp : PimOp<"core", [SingleBlock]> {
}]; }];
} }
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Execute equivalent batched core bodies";
let regions = (region SizedRegion<1>:$body);
let arguments = (ins
I32Attr:$laneCount,
Variadic<PimTensor>:$weights,
Variadic<PimTensor>:$inputs
);
let assemblyFormat = [{
`lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)`
}];
}
def PimHaltOp : PimOp<"halt", [Terminator]> { def PimHaltOp : PimOp<"halt", [Terminator]> {
let summary = "Halt execution of the core"; let summary = "Halt execution of the core";
@@ -65,6 +81,20 @@ def PimSendOp : PimOp<"send", []> {
}]; }];
} }
def PimSendBatchOp : PimOp<"send_batch", []> {
let summary = "Send a per-lane tensor to target cores from a batched core";
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
DenseI32ArrayAttr:$targetCoreIds
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
}];
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core"; let summary = "Receive a tensor from another core";
@@ -89,6 +119,30 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let summary = "Receive per-lane tensors from source cores into a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
}];
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory"; let summary = "Copy a memory region from host memory into device memory";
@@ -115,6 +169,32 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
let arguments = (ins
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceTargetMutable();
}
}];
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
}];
}
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory"; let summary = "Copy a memory region from device memory into host memory";

View File

@@ -1,6 +1,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "OpBufferizationInterfaces.hpp" #include "OpBufferizationInterfaces.hpp"
@@ -65,6 +66,32 @@ struct MemCopyHostToDevOpInterface
} }
}; };
struct MemCopyHostToDevBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
if (failed(deviceTargetOpt))
return failure();
auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
if (failed(hostSourceOpt))
return failure();
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
memCopyHostToDevOp,
deviceTargetOpt->getType(),
*deviceTargetOpt,
*hostSourceOpt,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
};
struct MemCopyDevToHostOpInterface struct MemCopyDevToHostOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -122,6 +149,127 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
} }
}; };
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveBatchOp>(op);
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return false;
}
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
auto coreBatchOp = cast<PimCoreBatchOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return {};
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
return false;
}
FailureOr<BufferLikeType>
getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto coreBatchOp = cast<PimCoreBatchOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return failure();
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
return memRefType;
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto coreBatchOp = cast<PimCoreBatchOp>(op);
SmallVector<Value> weights;
SmallVector<Value> inputs;
weights.reserve(coreBatchOp.getWeights().size());
inputs.reserve(coreBatchOp.getInputs().size());
for (Value weight : coreBatchOp.getWeights()) {
if (isa<TensorType>(weight.getType())) {
auto weightOpt = getBuffer(rewriter, weight, options, state);
if (failed(weightOpt))
return failure();
weights.push_back(*weightOpt);
}
else {
weights.push_back(weight);
}
}
for (Value input : coreBatchOp.getInputs()) {
if (isa<TensorType>(input.getType())) {
auto inputOpt = getBuffer(rewriter, input, options, state);
if (failed(inputOpt))
return failure();
inputs.push_back(*inputOpt);
}
else {
inputs.push_back(input);
}
}
rewriter.setInsertionPoint(coreBatchOp);
auto newOp = PimCoreBatchOp::create(
rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
newOp.getProperties().setOperandSegmentSizes({static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName))
newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr);
rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin());
for (Block& block : newOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
return failure();
rewriter.eraseOp(coreBatchOp);
return success();
}
};
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> { struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -178,8 +326,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>( replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt); rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
return success(); return success();
} }
}; };
@@ -203,8 +353,10 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimMVMOp>( replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt); rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
return success(); return success();
} }
}; };
@@ -283,8 +435,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) { void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx); PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpInterface>(*ctx); PimVMMOp::attachInterface<VMMOpInterface>(*ctx);

View File

@@ -3,12 +3,17 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Threading.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
@@ -40,15 +45,45 @@ private:
void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation(); auto moduleOp = getOperation();
// Refactor this into a function
{
auto funcOp = getPimEntryFunc(moduleOp);
auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
MLIRContext* ctx = moduleOp.getContext();
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
// Again, allocate state LOCALLY per thread/function
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
coreOp.emitError("Failed to bufferize PIM and Spatial ops");
return failure();
}
return success();
});
if (failed(result)) {
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
signalPassFailure();
}
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
});
// One-Shot-Bufferization // One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options; bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true; options.allowUnknownOps = true;
bufferization::BufferizationState state; bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure(); signalPassFailure();
} }
}
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
@@ -57,7 +92,18 @@ void PimBufferizationPass::runOnOperation() {
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
populateWithGenerated(patterns); populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { // Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
bool hasFailed = false;
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
return WalkResult::advance();
if (failed(applyPartialConversion(op, target, frozenPatterns)))
hasFailed = true;
return WalkResult::skip();
});
if (hasFailed) {
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -93,16 +139,21 @@ void PimBufferizationPass::runOnOperation() {
} }
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { auto markWeights = [&](Operation* op) {
bool isAlwaysWeight = !getGlobalOp->getUsers().empty() walkPimMvmVmmWeightUses(op, [&](OpOperand& weightUse) {
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }); Value weight = weightUse.get();
if (isAlwaysWeight) { auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return;
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());
markWeightAlways(getGlobalOp); markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp); markWeightAlways(globalMemrefOp);
}
}); });
};
funcOp.walk([&](PimCoreOp coreOp) { markWeights(coreOp); });
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
} }
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); } std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }

View File

@@ -2,7 +2,11 @@ add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td) add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps add_pim_library(SpatialOps
Channels.cpp
SpatialOps.cpp SpatialOps.cpp
SpatialOpsAsm.cpp
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp

View File

@@ -0,0 +1,120 @@
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
using namespace mlir;
namespace onnx_mlir::spatial {
namespace {
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
if (!endpoints.send || !endpoints.receive)
return failure();
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
return failure();
}
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
return failure();
}
if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) {
endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type");
return failure();
}
return success();
}
} // namespace
Channels::Channels(func::FuncOp funcOp) {
if (!funcOp)
return;
funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
}
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
void Channels::insertSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp);
nextChannelId = std::max(nextChannelId, channelId + 1);
endpoints[channelId].send = sendOp;
}
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp);
nextChannelId = std::max(nextChannelId, channelId + 1);
endpoints[channelId].receive = receiveOp;
}
void Channels::eraseSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp);
auto it = endpoints.find(channelId);
if (it == endpoints.end())
return;
it->second.send = {};
if (!it->second.receive)
endpoints.erase(it);
}
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp);
auto it = endpoints.find(channelId);
if (it == endpoints.end())
return;
it->second.receive = {};
if (!it->second.send)
endpoints.erase(it);
}
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
auto it = endpoints.find(id);
if (it == endpoints.end())
return failure();
return it->second;
}
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
auto endpointsOr = lookup(getChannelId(sendOp));
if (failed(endpointsOr) || !endpointsOr->receive)
return failure();
return endpointsOr->receive;
}
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
auto endpointsOr = lookup(getChannelId(receiveOp));
if (failed(endpointsOr) || !endpointsOr->send)
return failure();
return endpointsOr->send;
}
LogicalResult Channels::verify() const {
for (const auto& [channelId, pair] : endpoints) {
if (!pair.send || !pair.receive) {
if (pair.send) {
auto sendOp = pair.send;
sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
}
else if (pair.receive) {
auto receiveOp = pair.receive;
receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
}
return failure();
}
if (failed(verifyEndpointPair(pair)))
return failure();
}
return success();
}
} // namespace onnx_mlir::spatial

View File

@@ -0,0 +1,43 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir::spatial {
struct ChannelEndpoints {
SpatChannelSendOp send;
SpatChannelReceiveOp receive;
};
class Channels {
public:
using ChannelId = int64_t;
explicit Channels(mlir::func::FuncOp funcOp);
ChannelId allocate();
void insertSend(SpatChannelSendOp sendOp);
void insertReceive(SpatChannelReceiveOp receiveOp);
void eraseSend(SpatChannelSendOp sendOp);
void eraseReceive(SpatChannelReceiveOp receiveOp);
llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;
mlir::LogicalResult verify() const;
private:
ChannelId nextChannelId = 0;
llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
};
} // namespace onnx_mlir::spatial

View File

@@ -9,7 +9,6 @@ def SpatialDialect : Dialect {
let name = "spat"; let name = "spat";
let summary = "Dialect designed for deep learning computation in a spatial architecture"; let summary = "Dialect designed for deep learning computation in a spatial architecture";
let cppNamespace = "::onnx_mlir::spatial"; let cppNamespace = "::onnx_mlir::spatial";
let useDefaultTypePrinterParser = 1;
} }
class SpatOp<string mnemonic, list<Trait> traits = []> : class SpatOp<string mnemonic, list<Trait> traits = []> :
@@ -19,20 +18,11 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
def SpatTensor : def SpatTensor :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<SpatialDialect, name, traits> {
let mnemonic = typeMnemonic;
}
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
let summary = "Virtual channel type";
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute region with attached constant weights"; let summary = "Compute region with attached constant weights";
let arguments = (ins let arguments = (ins
@@ -48,10 +38,27 @@ def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegmen
let hasVerifier = 1; let hasVerifier = 1;
let hasFolder = 1; let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
let assemblyFormat = [{ def SpatComputeBatch : SpatOp<"compute_batch",
`[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body [SingleBlock, AttrSizedOperandSegments]> {
}]; let summary = "Compressed batch of independent equivalent compute lanes";
let arguments = (ins
I32Attr:$laneCount,
Variadic<SpatTensor>:$weights,
Variadic<SpatTensor>:$inputs
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
} }
def SpatYieldOp : SpatOp<"yield", [Terminator]> { def SpatYieldOp : SpatOp<"yield", [Terminator]> {
@@ -61,51 +68,66 @@ def SpatYieldOp : SpatOp<"yield", [Terminator]> {
Variadic<SpatTensor>:$outputs Variadic<SpatTensor>:$outputs
); );
let assemblyFormat = [{ let hasCustomAssemblyFormat = 1;
$outputs attr-dict `:` type($outputs) }
}];
def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
let arguments = (ins
SpatTensor:$input
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatConcatOp : SpatOp<"concat", []> {
let summary = "Concatenate tensors with compact Spatial operand syntax";
let arguments = (ins
I64Attr:$axis,
Variadic<SpatTensor>:$inputs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Communication // Communication
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def SpatChannelNewOp : SpatOp<"channel_new", []> {
let summary = "Create a new virtual channel";
let results = (outs
SpatChannelType:$channel
);
let builders = [
OpBuilder<(ins ), [{
$_state.addTypes(SpatChannelType());
}]>
];
let assemblyFormat = [{
attr-dict
}];
}
def SpatChannelSendOp : SpatOp<"channel_send", []> { def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a channel"; let summary = "Send a tensor through a logical channel";
let arguments = (ins let arguments = (ins
SpatChannelType:$channel, I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId,
SpatTensor:$input SpatTensor:$input
); );
let assemblyFormat = [{ let assemblyFormat = [{
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)` $input attr-dict `:` type($input)
}]; }];
} }
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a channel"; let summary = "Receive a tensor from a logical channel";
let arguments = (ins let arguments = (ins
SpatChannelType:$channel I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId
); );
let results = (outs let results = (outs
@@ -113,37 +135,70 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
); );
let assemblyFormat = [{ let assemblyFormat = [{
$channel attr-dict `:` `(` type($channel) `->` type($output) `)` attr-dict `:` type($output)
}]; }];
} }
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> { def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
let summary = "Broadcast a tensor through a shared channel buffer"; let summary = "Send multiple tensors through logical channels";
let arguments = (ins let arguments = (ins
SpatChannelType:$channel, DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
let summary = "Receive multiple tensors from logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
);
let results = (outs
Variadic<SpatTensor>:$outputs
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
SpatTensor:$input SpatTensor:$input
); );
let assemblyFormat = [{ let hasVerifier = 1;
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)` let hasCustomAssemblyFormat = 1;
}];
} }
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> { def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let summary = "Receive a tensor from a shared channel buffer"; let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins let arguments = (ins
SpatChannelType:$channel DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
); );
let results = (outs let results = (outs
SpatTensor:$output SpatTensor:$output
); );
let assemblyFormat = [{ let hasVerifier = 1;
$channel attr-dict `:` `(` type($channel) `->` type($output) `)` let hasCustomAssemblyFormat = 1;
}];
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@@ -1,31 +1,3 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -45,239 +17,6 @@ void SpatialDialect::initialize() {
>(); >();
} }
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
// Verify that the matrix, vector and output shapes have rank 2
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
// Verify that the matrix shape is (N, M)
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
// Verify that the vector shape is (M, 1)
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
// Verify that the output shape is (N, 1)
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
// Verify that the matrix, vector and output shapes have rank 4
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
// Verify that the matrix shape is (N, M, 1, 1)
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
// Verify that the vector shape is (1, M, 1, 1)
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
// Verify that the output shape is (1, N, 1, 1)
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp());
if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
if (coreOp)
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure();
}
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Two possible accepted shapes:
1. matrix: (N, M); vector: (M, 1); output: (N, 1)
2. matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1)
*/
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
else if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
else
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Accepted shape:
1. vector: (1, N); matrix: (N, M); output: (1, M)
*/
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vector1 = vectorShape[0];
int64_t vectorN = vectorShape[1];
if (vectorN != N || vector1 != 1)
return emitError("vector shape must be (N, 1)");
int64_t output1 = outputShape[0];
int64_t outputM = outputShape[1];
if (outputM != M || output1 != 1)
return emitError("output shape must be (M, 1)");
return success();
}
LogicalResult SpatVAddOp::verify() {
// At least two operands
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatVMaxOp::verify() {
// At least two operands
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size()) {
return emitError("ComputeOp must have same number of results as yieldOp "
"operands");
}
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it);
// Same type and compatible shape
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
return emitError("ComputeOp output must be of the same type as yieldOp "
"operand");
}
// Same encoding
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
return emitError("ComputeOp output must have the same encoding as "
"yieldOp operand");
}
}
else {
return emitError("ComputeOp output has an encoding while yieldOp "
"operand does not have one");
}
}
else {
// If result does not have an encoding, yield shouldn't either
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if "
"yieldOp operand has one");
}
}
}
}
// Check that each block argument is used
for (auto arg : block.getArguments())
if (arg.use_empty())
return emitError("ComputeOp block argument is not used");
return success();
}
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
if (!llvm::hasSingleElement(block))
return failure();
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
if (!yieldOp)
return failure();
for (Value yieldedValue : yieldOp.getOperands()) {
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArg.getOwner() == &block) {
results.push_back(getOperand(blockArg.getArgNumber()));
continue;
}
}
results.push_back(yieldedValue);
}
return success();
}
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -0,0 +1,912 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LogicalResult.h"
#include <string>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
enum class ListDelimiter {
Square,
Paren
};
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
}
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
}
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "[" : "(");
}
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
}
template <typename EntryT, typename ParseEntryFn>
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
}
template <typename IntT>
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
if (parser.parseLSquare())
return failure();
if (succeeded(parser.parseOptionalRSquare()))
return success();
while (true) {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0)
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
if (succeeded(parser.parseOptionalRSquare()))
break;
if (parser.parseComma())
return failure();
}
return success();
}
template <typename RangeT, typename PrintEntryFn>
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename IntT>
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printer << "[";
for (size_t index = 0; index < values.size();) {
if (index != 0)
printer << ", ";
auto findEqualRunEnd = [&](size_t start) {
size_t end = start + 1;
while (end < values.size() && values[end] == values[start])
++end;
return end;
};
size_t firstRunEnd = findEqualRunEnd(index);
size_t repeatCount = firstRunEnd - index;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[index];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[index]);
if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount;
if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
printer << values[index] << " to " << lastValue;
if (step != 1)
printer << " by " << step;
if (repeatCount > 1)
printer << " x" << repeatCount;
index = progressionEnd;
continue;
}
if (repeatCount > 1) {
printer << values[index] << " x" << repeatCount;
index = firstRunEnd;
continue;
}
printer << values[index];
index = firstRunEnd;
}
printer << "]";
}
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
printCloseDelimiter(printer, delimiter);
}
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
printCloseDelimiter(printer, delimiter);
}
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
}
return success();
}
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
}
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
static void printChannelMetadata(OpAsmPrinter& printer,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
printer << " channels ";
printCompressedIntegerList(printer, channelIds);
printer << " from ";
printCompressedIntegerList(printer, sourceCoreIds);
printer << " to ";
printCompressedIntegerList(printer, targetCoreIds);
}
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
return parser.getBuilder().getDenseI64ArrayAttr(values);
}
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
return parser.getBuilder().getDenseI32ArrayAttr(values);
}
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static void buildImplicitRegionArgs(OpAsmParser& parser,
ArrayRef<Type> inputTypes,
SmallVectorImpl<std::string>& generatedNames,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
generatedNames.reserve(inputTypes.size());
arguments.reserve(inputTypes.size());
for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
generatedNames.push_back("arg" + std::to_string(index + 1));
OpAsmParser::Argument arg;
arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
arg.type = inputType;
arguments.push_back(arg);
}
}
} // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getOutputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputs;
SmallVector<Type> outputTypes;
OpAsmParser::UnresolvedOperand firstOutput;
OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
if (firstOutputResult.has_value()) {
if (failed(*firstOutputResult))
return failure();
if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, outputs))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (outputs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInput().getType());
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
}
ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<Type> outputTypes;
if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parser.parseType(inputType) || parser.parseArrow()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (parser.resolveOperand(input, inputType, result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void SpatConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis();
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
int64_t axis = 0;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
Type outputType;
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (result.attributes.get("axis"))
return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " core_id " << coreIdAttr.getInt();
printer.printOptionalAttrDict((*this)->getAttrs(),
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<std::string> generatedArgNames;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t coreId = 0;
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
if (hasCoreId && parser.parseInteger(coreId))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"core_id cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " args = ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
printer << " core_ids ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int32_t laneCount = 0;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<std::string> generatedArgNames;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
return failure();
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
return failure();
if (succeeded(parser.parseOptionalKeyword("args"))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
return failure();
}
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
return failure();
}
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getResultTypes());
}
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> outputTypes;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputTypes);
return success();
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutput().getType());
}
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputType);
return success();
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,35 @@
#include "mlir/IR/Block.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
if (!llvm::hasSingleElement(block))
return failure();
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
if (!yieldOp)
return failure();
for (Value yieldedValue : yieldOp.getOperands()) {
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArg.getOwner() == &block) {
results.push_back(getOperand(blockArg.getArgNumber()));
continue;
}
}
results.push_back(yieldedValue);
}
return success();
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,438 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error.
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = dyn_cast<SpatComputeBatch>(weightedOp->getParentOp())) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto batchOp = op->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return failure();
return batchOp.getLaneCount();
}
static LogicalResult verifyManyChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
size_t valueCount) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.size() != valueCount)
return op->emitError("channel metadata length must match the number of values");
return success();
}
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
if (types.empty())
return op->emitError() << kind << " must carry at least one value";
Type firstType = types.front();
for (Type type : types.drop_front())
if (type != firstType)
return op->emitError() << kind << " values must all have the same type";
return success();
}
static LogicalResult verifyBatchChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
return success();
}
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return op->emitError("body must terminate with spat.yield");
if (outputTypes.empty()) {
if (yieldOp.getNumOperands() != 0)
return op->emitError("body yield must be empty when compute_batch has no results");
}
else {
if (yieldOp.getNumOperands() != 1)
return op->emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputTypes[0])
return op->emitError("body yield type must match output type");
}
for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
}
return success();
}
} // namespace
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vector1 = vectorShape[0];
int64_t vectorN = vectorShape[1];
if (vectorN != N || vector1 != 1)
return emitError("vector shape must be (1, N)");
int64_t output1 = outputShape[0];
int64_t outputM = outputShape[1];
if (outputM != M || output1 != 1)
return emitError("output shape must be (1, M)");
return success();
}
LogicalResult SpatVAddOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatVMaxOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatExtractRowsOp::verify() {
auto inputType = dyn_cast<ShapedType>(getInput().getType());
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
return emitError("input must be a rank-2 shaped type");
int64_t numRows = inputType.getShape()[0];
int64_t numCols = inputType.getShape()[1];
Type elementType = inputType.getElementType();
if (numRows >= 0 && static_cast<int64_t>(getNumResults()) != numRows)
return emitError("number of outputs must match the number of input rows");
for (Type output : getResultTypes()) {
auto outputType = dyn_cast<ShapedType>(output);
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
return emitError("outputs must all be rank-2 shaped types");
if (outputType.getElementType() != elementType)
return emitError("output element types must match input element type");
auto outputShape = outputType.getShape();
if (outputShape[0] != 1)
return emitError("each output must have exactly one row");
if (numCols >= 0 && outputShape[1] != numCols)
return emitError("output column count must match input column count");
}
return success();
}
LogicalResult SpatConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
if (!outputType || !outputType.hasRank())
return emitError("output must be a ranked shaped type");
int64_t axis = getAxis();
int64_t rank = outputType.getRank();
if (axis < 0 || axis >= rank)
return emitError("axis must be within the output rank");
int64_t concatenatedDimSize = 0;
bool concatenatedDimDynamic = false;
Type outputElementType = outputType.getElementType();
for (Value input : getInputs()) {
auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType || !inputType.hasRank())
return emitError("inputs must be ranked shaped types");
if (inputType.getRank() != rank)
return emitError("all inputs must have the same rank as the output");
if (inputType.getElementType() != outputElementType)
return emitError("all inputs must have the same element type as the output");
for (int64_t dim = 0; dim < rank; ++dim) {
if (dim == axis)
continue;
int64_t inputDim = inputType.getDimSize(dim);
int64_t outputDim = outputType.getDimSize(dim);
if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
return emitError("non-concatenated dimensions must match the output shape");
}
int64_t inputConcatDim = inputType.getDimSize(axis);
if (ShapedType::isDynamic(inputConcatDim)) {
concatenatedDimDynamic = true;
continue;
}
concatenatedDimSize += inputConcatDim;
}
int64_t outputConcatDim = outputType.getDimSize(axis);
if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim)
return emitError("output concatenated dimension must equal the sum of input sizes");
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size())
return emitError("ComputeOp must have same number of results as yieldOp operands");
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it);
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
return emitError("ComputeOp output must be of the same type as yieldOp operand");
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
}
else {
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
}
}
else if (dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
}
}
}
for (auto arg : block.getArguments())
if (arg.use_empty())
return emitError("ComputeOp block argument is not used");
return success();
}
LogicalResult SpatChannelSendManyOp::verify() {
if (failed(verifyManyChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
}
LogicalResult SpatChannelReceiveManyOp::verify() {
if (failed(verifyManyChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
if (count <= 0)
return emitError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count);
if (getWeights().size() % laneCountSz != 0)
return emitError("number of weights must be a multiple of laneCount");
if (!getInputs().empty() && getInputs().size() != laneCountSz)
return emitError("number of inputs must be either 0 or laneCount");
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
return emitError("number of outputs must be either 0 or laneCount");
size_t weightsPerLane = getWeights().size() / laneCountSz;
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
Type weightType = getWeights()[weightIndex].getType();
for (size_t lane = 1; lane < laneCountSz; ++lane)
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
return emitError("corresponding weights across lanes must have the same type");
}
if (!getInputs().empty()) {
Type inputType = getInputs()[0].getType();
for (Value in : getInputs().drop_front())
if (in.getType() != inputType)
return emitError("all inputs must have the same type");
}
if (!getOutputs().empty()) {
Type outputType = getOutputs()[0].getType();
for (Value out : getOutputs().drop_front())
if (out.getType() != outputType)
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
if (!coreIdsAttr)
return emitError("compute_batch core_id attribute must be a dense i32 array");
if (coreIdsAttr.size() != laneCountSz)
return emitError("compute_batch core_id array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
return emitError("compute_batch core_id values must be positive");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch core_id values must be distinct");
}
Block& block = getBody().front();
if (getInputs().empty()) {
if (block.getNumArguments() != 0)
return emitError("compute_batch body must have no block arguments when there are no inputs");
}
else {
if (block.getNumArguments() != 1)
return emitError("compute_batch body must have exactly one block argument");
if (block.getArgument(0).getType() != getInputs()[0].getType())
return emitError("body block argument type must match input type");
}
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -3,13 +3,23 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <iterator> #include <iterator>
#include <numeric>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"
#include "Graph.hpp" #include "Graph.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Support/TypeUtilities.hpp" #include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -17,46 +27,729 @@ namespace spatial {
using namespace mlir; using namespace mlir;
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) { namespace {
using SpatCompute = onnx_mlir::spatial::SpatCompute;
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
struct VirtualGraph {
std::vector<VirtualNode> nodes;
std::vector<IndexedEdge> edges;
};
struct TimingInfo {
std::vector<Time> aest;
std::vector<Time> alst;
std::vector<size_t> topologicalOrder;
bool valid = false;
};
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
CPU cpuCount = 0;
size_t mergedNodeCount = 0;
size_t maxMergeGroupSize = 0;
};
size_t getSchedulingCpuBudget() {
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return std::numeric_limits<size_t>::max();
}
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
}
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
size_t chunkIndex = 0;
if (static_cast<size_t>(lane) < largeChunkSpan)
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
else
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
return getBatchChunkForIndex(batch, chunkIndex);
}
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
if (startIndex == endIndex)
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
});
return aggregatedEdges;
}
Weight getComputeBodyWeight(Region& body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : body)
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
CrossbarUsage crossbarUsage = 0;
for (auto& block : body)
for (auto& op : block)
if (isa<SpatWeightedVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
}
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op);
SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
inputs.push_back(batch.getInputs()[lane]);
return inputs;
}
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
Operation* op = value.getDefiningOp();
if (!op)
return std::nullopt;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource();
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto spatCompute = dyn_cast<SpatCompute>(op))
return ComputeInstance {spatCompute.getOperation(), 0, 1};
if (auto batch = dyn_cast<SpatComputeBatch>(op))
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
return std::nullopt;
}
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
SmallVector<ComputeInstance> instances;
for (Region& region : entryOp->getRegions()) {
for (Block& block : region) {
for (Operation& op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
instances.push_back({spatCompute.getOperation(), 0, 1});
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
}
}
}
}
return instances;
}
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(computeInstances.size());
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
VirtualNode node;
node.originalComputeIndices.push_back(index);
node.weight = getComputeInstanceWeight(computeInstance);
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
graph.nodes.push_back(std::move(node));
}
graph.edges = aggregateEdges(edges);
return graph;
}
TimingInfo computeTiming(const VirtualGraph& graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
timing.alst.assign(nodeCount, 0);
timing.topologicalOrder.reserve(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
for (auto [start, end, weight] : graph.edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
Weight edgeWeight = static_cast<Weight>(weight);
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
children[startIndex].push_back({endIndex, edgeWeight});
parents[endIndex].push_back({startIndex, edgeWeight});
incomingEdgeCount[endIndex]++;
}
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode& node = graph.nodes[nodeIndex];
if (!node.originalComputeIndices.empty())
return node.originalComputeIndices.front();
return nodeIndex;
};
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
size_t lhsKey = getVirtualNodeOrderKey(lhs);
size_t rhsKey = getVirtualNodeOrderKey(rhs);
if (lhsKey != rhsKey)
return lhsKey > rhsKey;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t i = 0; i < nodeCount; ++i)
if (incomingEdgeCount[i] == 0)
readyNodes.push(i);
while (!readyNodes.empty()) {
size_t current = readyNodes.top();
readyNodes.pop();
timing.topologicalOrder.push_back(current);
for (auto [child, weight] : children[current]) {
(void) weight;
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push(child);
}
}
if (timing.topologicalOrder.size() != nodeCount)
return timing;
Time dcpl = 0;
for (size_t nodeIndex : timing.topologicalOrder) {
Time maxParentAest = 0;
for (auto [parent, transferCost] : parents[nodeIndex]) {
maxParentAest =
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
}
timing.aest[nodeIndex] = maxParentAest;
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
}
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
Time minAlst = std::numeric_limits<Time>::max();
if (children[nodeIndex].empty())
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
for (auto [child, transferCost] : children[nodeIndex]) {
minAlst =
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
}
timing.alst[nodeIndex] = minAlst;
}
timing.valid = true;
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto& neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
return lhsSlack < rhsSlack;
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
};
windowSize = std::min(windowSize, ranked.size());
if (windowSize == 0)
return {};
if (windowSize == ranked.size()) {
llvm::sort(ranked, isHigherPriority);
return ranked;
}
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
if (criticalPoolSize < ranked.size())
std::nth_element(
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
std::vector<char> inCriticalPool(ranked.size(), false);
for (size_t i = 0; i < criticalPoolSize; ++i)
inCriticalPool[ranked[i]] = true;
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
std::vector<size_t> selected;
std::vector<char> inWindow(ranked.size(), false);
selected.reserve(windowSize);
struct FrontierEntry {
size_t node;
};
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
selected.push_back(node);
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour] && eligible[neighbour])
frontier.push({neighbour});
};
addToWindow(seed, inCriticalPool);
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, inCriticalPool);
}
if (selected.size() < windowSize) {
std::vector<char> anyNode(ranked.size(), true);
for (size_t node : selected)
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour])
frontier.push({neighbour});
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, anyNode);
}
}
if (selected.size() < windowSize) {
llvm::sort(ranked, isHigherPriority);
for (size_t node : ranked) {
if (selected.size() == windowSize)
break;
if (!inWindow[node]) {
inWindow[node] = true;
selected.push_back(node);
}
}
}
llvm::sort(selected, isHigherPriority);
return selected;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
if (mappedStart == -1 || mappedEnd == -1)
continue;
windowEdges.push_back({mappedStart, mappedEnd, weight});
}
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
windowWeights.reserve(selectedNodes.size());
windowCrossbarUsage.reserve(selectedNodes.size());
windowNodeOrderKeys.reserve(selectedNodes.size());
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
windowWeights.push_back(graph.nodes[nodeIndex].weight);
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
}
GraphDCP windowGraph(
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
if (coresCount.getValue() > 0)
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
windowGraph.setContext(context);
windowGraph.runDcp();
WindowScheduleResult result;
result.cpuCount = windowGraph.cpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
result.mergedNodeCount += scheduledTasks.size();
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph& graph,
ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph,
std::vector<size_t>& oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
if (timing.valid)
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
topologicalRank[nodeIndex] = rank;
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto& mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
return topologicalRank[lhs] < topologicalRank[rhs];
return lhs < rhs;
});
}
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
std::vector<size_t> newNodeRank;
oldToNewNode.assign(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
newNodeRank.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
newNodeRank.push_back(topologicalRank[nodeIndex]);
continue;
}
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
}
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
memberNode.originalComputeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
if (!mergedAny)
return false;
std::vector<IndexedEdge> remappedEdges;
remappedEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
if (newNodeRank[newStart] >= newNodeRank[newEnd])
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return true;
}
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
if (nodeCount > static_cast<size_t>(maxCpuCount))
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
return windowSize;
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
DCPAnalysisResult result;
TimingInfo timing = computeTiming(graph);
std::vector<size_t> virtualNodeOrder;
if (timing.valid) {
virtualNodeOrder = std::move(timing.topologicalOrder);
}
else {
virtualNodeOrder.resize(graph.nodes.size());
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
}
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalComputeIndices)
originalComputeToCpu[originalIndex] = cpu;
}
result.dominanceOrderCompute.reserve(computeInstances.size());
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
size_t cpu = originalComputeToCpu[originalIndex];
result.dominanceOrderCompute.push_back(computeInstance);
result.computeToCpuMap[computeInstance] = cpu;
result.cpuToLastComputeMap[cpu] = computeInstance;
}
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
return result;
}
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
DCPAnalysisResult result;
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
if (scheduledTasks.empty())
continue;
for (const auto& task : scheduledTasks)
result.computeToCpuMap[computeInstances[task.nodeIndex]] = cpu;
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
}
return result;
}
DCPAnalysisResult
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
SmallVector<Weight> nodeWeights;
SmallVector<CrossbarUsage> nodeCrossbarUsage;
SmallVector<int64_t> nodeOrderKeys;
nodeWeights.reserve(computeInstances.size());
nodeCrossbarUsage.reserve(computeInstances.size());
nodeOrderKeys.reserve(computeInstances.size());
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
nodeWeights.push_back(getComputeInstanceWeight(instance));
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
nodeOrderKeys.push_back(static_cast<int64_t>(index));
}
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context);
graphDCP.runDcp();
return buildResultFromScheduledGraph(graphDCP, computeInstances);
}
} // namespace
SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op) if (!op)
return {}; return {};
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) { while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp(); op = extract.getSource().getDefiningOp();
if (!op) if (!op)
return {}; return {};
} }
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op)) if (auto res = dyn_cast<SpatCompute>(op))
return res; return res;
return {}; return {};
} }
DCPAnalysisResult DCPAnalysis::run() { DCPAnalysisResult DCPAnalysis::run() {
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes; SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
llvm::SmallVector<IndexedEdge, 10> edges; SmallVector<IndexedEdge, 10> edges;
for (auto& region : entryOp->getRegions())
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
spatWeightedComputes.push_back(spatWeightedCompute);
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) { llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
for (Value input : spatWeightedCompute.getInputs()) { instanceToIndex.reserve(computeInstances.size());
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) { for (auto [index, instance] : llvm::enumerate(computeInstances))
auto producerIt = llvm::find(spatWeightedComputes, producerCompute); instanceToIndex[instance] = index;
assert(producerIt != spatWeightedComputes.end());
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt); for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
ResultRange outputs = producerCompute.getResults(); for (Value input : getComputeInstanceInputs(computeInstance)) {
int64_t totalSize = 0; if (auto producerInstance = getOriginalComputeInstance(input)) {
for (auto output : outputs) { auto producerIt = instanceToIndex.find(*producerInstance);
ShapedType resultType = cast<ShapedType>(output.getType()); assert(producerIt != instanceToIndex.end());
totalSize += getSizeInBytes(resultType); auto indexStartEdge = producerIt->second;
} edges.push_back({static_cast<int64_t>(indexStartEdge),
edges.push_back({indexStartEdge, indexEndEdge, totalSize}); static_cast<int64_t>(indexEndEdge),
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
} }
} }
} }
GraphDCP graphDCP(spatWeightedComputes, edges);
graphDCP.setContext(entryOp->getContext()); if (dcpCriticalWindowSize.getValue() == 0)
graphDCP.runDcp(); return runLegacyDcp(computeInstances, edges, entryOp->getContext());
return graphDCP.getResult();
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
size_t iteration = 0;
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty()) {
if (oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount);
return false;
}
VirtualGraph coarsenedGraph;
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount,
windowSchedule.mergeGroups.size(),
windowSchedule.mergedNodeCount,
windowSchedule.maxMergeGroupSize,
coarsenedGraph.nodes.size(),
oldNodeCount - coarsenedGraph.nodes.size());
virtualGraph = std::move(coarsenedGraph);
return true;
};
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
}
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
}
SmallVector<size_t> selectedNodes;
auto criticalWindow =
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
selectedNodes.size());
break;
}
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
}
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
} }
} // namespace spatial } // namespace spatial

View File

@@ -5,15 +5,28 @@
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <vector> #include <vector>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
// A scheduling identity that covers both spat.compute and scheduled shards of
// spat.compute_batch.
struct ComputeInstance {
mlir::Operation* op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance& other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
struct DCPAnalysisResult { struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute; std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap; llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu; llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap; llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
}; };
namespace onnx_mlir { namespace onnx_mlir {
@@ -34,3 +47,21 @@ public:
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<ComputeInstance> {
static ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const ComputeInstance& v) {
return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
}
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
return a == b;
}
};
} // namespace llvm

View File

@@ -6,7 +6,7 @@
// consumer land on different CPUs. // consumer land on different CPUs.
// //
// Output: an assignment of every task to a CPU and an order within that CPU, // Output: an assignment of every task to a CPU and an order within that CPU,
// aiming to minimise the overall critical-path length (DCPL). // aiming to minimize the overall critical-path length (DCPL).
// //
// Every task keeps two timing estimates: // Every task keeps two timing estimates:
// AEST - earliest start time, driven by parent completions + transfers. // AEST - earliest start time, driven by parent completions + transfers.
@@ -16,9 +16,9 @@
// Main loop (runDcp): // Main loop (runDcp):
// 1. Build a topological order and seed AEST/ALST from the unscheduled DAG. // 1. Build a topological order and seed AEST/ALST from the unscheduled DAG.
// 2. While there are ready tasks (all dependency parents scheduled): // 2. While there are ready tasks (all dependency parents scheduled):
// a. Pick the candidate with tightest slack (earliest AEST breaks ties). // a. Pick the candidate with the tightest slack (earliest AEST breaks ties).
// b. selectProcessor() tries every candidate CPU and picks the one that // b. selectProcessor() tries every candidate CPU and picks the one that
// minimises a composite cost (own slot + smallest unscheduled child). // minimizes a composite cost (own slot + the smallest unscheduled child).
// c. Commit the placement and refresh AEST/ALST. // c. Commit the placement and refresh AEST/ALST.
// d. Release any child whose dependency parents are now all scheduled. // d. Release any child whose dependency parents are now all scheduled.
// //
@@ -38,12 +38,14 @@
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <chrono> #include <chrono>
#include <cstdint>
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <queue>
#include <vector> #include <vector>
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"
@@ -61,6 +63,7 @@ namespace {
// Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set. // Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set.
struct SelectTimers { struct SelectTimers {
double findSlot = 0.0; double findSlot = 0.0;
double dedup = 0.0;
double precheck = 0.0; double precheck = 0.0;
double snapshotInsertUpdate = 0.0; double snapshotInsertUpdate = 0.0;
double childSlot = 0.0; double childSlot = 0.0;
@@ -71,9 +74,19 @@ struct SelectTimers {
long tasksProcessed = 0; long tasksProcessed = 0;
void dump(const char* label) const { void dump(const char* label) const {
std::fprintf(stderr, std::fprintf(stderr,
"[selectProfile:%s] tasks=%ld findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n", "[selectProfile:%s] tasks=%ld dedup=%.2fs findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs "
label, tasksProcessed, findSlot, precheck, snapshotInsertUpdate, childSlot, "childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
rollbackRestore, iterations, passedPrecheck, passedDcpl); label,
tasksProcessed,
dedup,
findSlot,
precheck,
snapshotInsertUpdate,
childSlot,
rollbackRestore,
iterations,
passedPrecheck,
passedDcpl);
} }
~SelectTimers() { ~SelectTimers() {
if (std::getenv("DCP_SELECT_PROFILE")) if (std::getenv("DCP_SELECT_PROFILE"))
@@ -84,6 +97,101 @@ static SelectTimers gSelectTimers;
} // namespace } // namespace
#endif #endif
namespace {
uint64_t mixHash(uint64_t seed, uint64_t value) {
seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
return seed;
}
uint64_t finishHash(uint64_t seed) {
seed ^= seed >> 33;
seed *= 0xff51afd7ed558ccdULL;
seed ^= seed >> 33;
seed *= 0xc4ceb9fe1a85ec53ULL;
seed ^= seed >> 33;
return seed;
}
uint64_t hashEdgeSignature(uint64_t neighborHash, Weight weight, uint64_t direction) {
uint64_t hash = mixHash(0x84222325cbf29ce4ULL, direction);
hash = mixHash(hash, neighborHash);
hash = mixHash(hash, static_cast<uint64_t>(weight));
return finishHash(hash);
}
struct CpuAestCache {
Time defaultAest = 0;
llvm::SmallDenseMap<CPU, Time, 8> colocatedParentAests;
Time get(CPU cpu) const {
auto it = colocatedParentAests.find(cpu);
if (it == colocatedParentAests.end())
return defaultAest;
return it->second;
}
};
struct CpuTimeMax {
CPU cpu = -1;
Time time = 0;
};
void updateCpuTimeMax(CpuTimeMax& first, CpuTimeMax& second, CPU cpu, Time time) {
if (first.cpu == cpu) {
first.time = std::max(first.time, time);
return;
}
if (second.cpu == cpu) {
second.time = std::max(second.time, time);
if (second.time > first.time)
std::swap(first, second);
return;
}
if (time >= first.time) {
second = first;
first = {cpu, time};
return;
}
if (time > second.time)
second = {cpu, time};
}
CpuAestCache computeCpuAestCache(TaskDCP* task) {
CpuAestCache cache;
llvm::SmallDenseMap<CPU, Time, 8> transferAestByCpu;
llvm::SmallDenseMap<CPU, Time, 8> localAestByCpu;
Time unscheduledTransferAest = 0;
for (const Edge& parentEdge : task->parents) {
Time parentFinish = addOrMax(parentEdge.first->getAest(), parentEdge.first->getWeight());
Time transferAest = addOrMax(parentFinish, getTransferCost(parentEdge.first, task));
if (std::optional<CPU> parentCpu = parentEdge.first->getCpu()) {
Time& cpuTransferAest = transferAestByCpu[*parentCpu];
cpuTransferAest = std::max(cpuTransferAest, transferAest);
Time& cpuLocalAest = localAestByCpu[*parentCpu];
cpuLocalAest = std::max(cpuLocalAest, parentFinish);
continue;
}
unscheduledTransferAest = std::max(unscheduledTransferAest, transferAest);
}
CpuTimeMax firstOther {-1, unscheduledTransferAest};
CpuTimeMax secondOther {-1, 0};
for (const auto& entry : transferAestByCpu)
updateCpuTimeMax(firstOther, secondOther, entry.first, entry.second);
cache.defaultAest = firstOther.time;
for (const auto& entry : localAestByCpu) {
CPU cpu = entry.first;
Time bestNonLocalParentAest = firstOther.cpu == cpu ? secondOther.time : firstOther.time;
cache.colocatedParentAests[cpu] = std::max(bestNonLocalParentAest, entry.second);
}
return cache;
}
} // namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Edge manipulation // Edge manipulation
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -157,6 +265,49 @@ std::vector<TaskDCP*> GraphDCP::getRoots() {
return tmp; return tmp;
} }
void GraphDCP::initTaskStructureHashes() {
taskStructureHashes.resize(nodes.size());
for (auto [index, task] : llvm::enumerate(nodes)) {
uint64_t hash = mixHash(0x7442b1129fd01363ULL, static_cast<uint64_t>(task.getWeight()));
hash = mixHash(hash, static_cast<uint64_t>(task.getCrossbarUsage()));
taskStructureHashes[index] = finishHash(hash);
}
std::vector<uint64_t> nextHashes(nodes.size());
std::vector<uint64_t> edgeHashes;
for (int iteration = 0; iteration < 4; ++iteration) {
for (auto [index, task] : llvm::enumerate(nodes)) {
uint64_t hash = mixHash(0x464dcab27ac82291ULL, taskStructureHashes[index]);
edgeHashes.clear();
edgeHashes.reserve(task.parents.size() + task.children.size());
for (const Edge& parent : task.parents)
if (!parent.isScheduling)
edgeHashes.push_back(
hashEdgeSignature(taskStructureHashes[getNodeIndex(parent.first)], parent.second, /*direction=*/0));
for (const Edge& child : task.children)
if (!child.isScheduling)
edgeHashes.push_back(
hashEdgeSignature(taskStructureHashes[getNodeIndex(child.first)], child.second, /*direction=*/1));
llvm::sort(edgeHashes);
hash = mixHash(hash, static_cast<uint64_t>(edgeHashes.size()));
for (uint64_t edgeHash : edgeHashes)
hash = mixHash(hash, edgeHash);
nextHashes[index] = finishHash(hash);
}
taskStructureHashes.swap(nextHashes);
}
}
// Compact dedup key for CPU `c` vs `candidate`: mixes candidateAest, crossbar
// usage, and the incremental cpu structure hash. No heap allocation.
uint64_t GraphDCP::computeCpuCandidateKey(Time candidateAest, CPU cpu) {
uint64_t hash = mixHash(0xd6e8feb86659fd93ULL, static_cast<uint64_t>(candidateAest));
hash = mixHash(hash, static_cast<uint64_t>(getCpuCrossbarUsage(cpu)));
auto it = cpuStructureHashes.find(cpu);
hash = mixHash(hash, it != cpuStructureHashes.end() ? it->second : 0ULL);
return finishHash(hash);
}
// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the // Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the
// neighbouring tasks and keeping the global topological order consistent. // neighbouring tasks and keeping the global topological order consistent.
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) { TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
@@ -165,6 +316,7 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
task->setCpu(cpu); task->setCpu(cpu);
task->setWeight(scheduledWeight); task->setWeight(scheduledWeight);
reserveTaskCrossbars(cpu, task); reserveTaskCrossbars(cpu, task);
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
auto& tasksInCpu = getOrCreateCpuTasks(cpu); auto& tasksInCpu = getOrCreateCpuTasks(cpu);
unsigned int numCpuTasks = tasksInCpu.size(); unsigned int numCpuTasks = tasksInCpu.size();
assert(position <= numCpuTasks && "Inserting in a not valid position"); assert(position <= numCpuTasks && "Inserting in a not valid position");
@@ -202,6 +354,7 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) { void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
releaseTaskCrossbars(cpu, task); releaseTaskCrossbars(cpu, task);
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
task->resetCpu(); task->resetCpu();
task->resetWeight(); task->resetWeight();
auto& scheduledTasks = getOrCreateCpuTasks(cpu); auto& scheduledTasks = getOrCreateCpuTasks(cpu);
@@ -272,6 +425,21 @@ bool GraphDCP::wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const
return nextUsage >= getCpuCrossbarCapacity(); return nextUsage >= getCpuCrossbarCapacity();
} }
size_t GraphDCP::crossbarsUsed() const {
CrossbarUsage crossbarEdge = static_cast<CrossbarUsage>(onnx_mlir::crossbarSize.getValue());
CrossbarUsage crossbarArea = crossbarEdge * crossbarEdge;
if (crossbarArea == 0)
return 0;
CrossbarUsage totalArea = 0;
for (const auto& [cpu, usage] : cpuCrossbarUsage)
totalArea = checkedAdd(totalArea, usage);
return static_cast<size_t>(totalArea / crossbarArea);
}
size_t GraphDCP::crossbarsAvailable() const {
return static_cast<size_t>(lastCpu) * onnx_mlir::crossbarCountInCore.getValue();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AEST / ALST computation // AEST / ALST computation
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -457,9 +625,9 @@ void GraphDCP::updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<T
for (TaskDCP* descendant : descendantsTopoOrder) for (TaskDCP* descendant : descendantsTopoOrder)
recomputeAest(descendant); recomputeAest(descendant);
const bool oldMaxInvalidated = maxCompletionTask != nullptr const bool oldMaxInvalidated =
&& (maxCompletionTask == task maxCompletionTask != nullptr
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask)); && (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
if (oldMaxInvalidated) { if (oldMaxInvalidated) {
// The pre-update max came from a modified task; its completion has moved // The pre-update max came from a modified task; its completion has moved
// upward, so modifiedMaxCompletion is an upper bound covering it. The // upward, so modifiedMaxCompletion is an upper bound covering it. The
@@ -524,9 +692,9 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
if (!process(descendant)) if (!process(descendant))
return false; return false;
const bool oldMaxInvalidated = maxCompletionTask != nullptr const bool oldMaxInvalidated =
&& (maxCompletionTask == task maxCompletionTask != nullptr
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask)); && (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
if (oldMaxInvalidated) { if (oldMaxInvalidated) {
dcpl = modifiedMaxCompletion; dcpl = modifiedMaxCompletion;
maxCompletion = modifiedMaxCompletion; maxCompletion = modifiedMaxCompletion;
@@ -547,6 +715,109 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
return true; return true;
} }
// Incrementally refreshes ALST after `task` was placed. The set of nodes whose
// ALST is structurally affected by the insertion is exactly
// `relations.ancestors {task}`: the task's outgoing transfer costs to
// same-CPU real children become 0, and new scheduling edges create parent
// relationships between `task` and its same-CPU neighbors. Every other node
// keeps its relative distance to the sink boundary and only absorbs the
// signed DCPL delta captured between `oldDcpl` and the now-updated `dcpl`.
void GraphDCP::updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl) {
Time newDcpl = getDcpl();
// If the AEST update saturated dcpl (e.g. rescue placement on a
// crossbar-exhausted CPU sets task weight to UINT64_MAX), the shift delta
// would be meaningless. Fall back to a full recompute for this step only.
if (newDcpl == std::numeric_limits<Time>::max()) {
initAlst();
return;
}
if (newDcpl != oldDcpl) {
const bool increased = newDcpl > oldDcpl;
const Time delta = increased ? (newDcpl - oldDcpl) : (oldDcpl - newDcpl);
for (TaskDCP& node : topologicalOrder) {
if (&node == task || relations.ancestors.contains(&node))
continue;
Time alst = node.getAlst();
node.setAlst(increased ? addOrMax(alst, delta) : subtractOrZero(alst, delta));
}
}
auto recomputeAlst = [&](TaskDCP* node) {
Time minAlst = std::numeric_limits<Time>::max();
if (!node->hasChildren())
minAlst = subtractOrZero(newDcpl, node->getWeight());
for (const Edge& childEdge : node->children)
minAlst = std::min(minAlst,
subtractOrZero(childEdge.first->getAlst(),
addOrMax(node->getWeight(), getTransferCost(node, childEdge.first))));
node->setAlst(minAlst);
};
// Walk the backward cone with a pending-children counter so that every
// ancestor is recomputed only after all of its affected children have
// been refreshed. This is resilient to staleness in the global
// `topologicalOrder` relative to freshly added scheduling edges.
llvm::DenseSet<TaskDCP*> affected = relations.ancestors;
affected.insert(task);
llvm::DenseMap<TaskDCP*, int> pendingAffectedChildren;
pendingAffectedChildren.reserve(affected.size());
std::vector<TaskDCP*> worklist;
worklist.reserve(affected.size());
for (TaskDCP* node : affected) {
int count = 0;
for (const Edge& childEdge : node->children)
if (affected.contains(childEdge.first))
count++;
pendingAffectedChildren[node] = count;
if (count == 0)
worklist.push_back(node);
}
while (!worklist.empty()) {
TaskDCP* node = worklist.back();
worklist.pop_back();
recomputeAlst(node);
for (const Edge& parentEdge : node->parents) {
if (!affected.contains(parentEdge.first))
continue;
auto it = pendingAffectedChildren.find(parentEdge.first);
assert(it != pendingAffectedChildren.end());
if (--it->second == 0)
worklist.push_back(parentEdge.first);
}
}
// Opt-in consistency check: verifies the incremental ALST result against a
// full initAlst() recomputation. Very expensive (O(V+E) per placement) - only
// enable when investigating suspected drift.
#ifdef DCP_DEBUG_CHECK_ALST
std::vector<Time> afterIncremental(nodes.size());
for (size_t i = 0; i < nodes.size(); ++i)
afterIncremental[i] = nodes[i].getAlst();
initAlst();
bool mismatched = false;
for (size_t i = 0; i < nodes.size(); ++i) {
if (afterIncremental[i] != nodes[i].getAlst()) {
if (!mismatched) {
llvm::errs() << "[alst-mismatch] placed=" << getNodeIndex(task) << " oldDcpl=" << oldDcpl
<< " newDcpl=" << newDcpl << " ancestors={";
for (TaskDCP* a : relations.ancestors)
llvm::errs() << getNodeIndex(a) << ",";
llvm::errs() << "}\n";
mismatched = true;
}
llvm::errs() << " node=" << i << " incremental=" << afterIncremental[i] << " full=" << nodes[i].getAlst()
<< " weight=" << nodes[i].getWeight()
<< " cpu=" << (nodes[i].isScheduled() ? (int) *nodes[i].getCpu() : -1) << " children=[";
for (const Edge& e : nodes[i].children)
llvm::errs() << getNodeIndex(e.first) << (e.isScheduling ? "s" : "")
<< "(tc=" << getTransferCost(&nodes[i], e.first) << ",alst=" << e.first->getAlst() << "),";
llvm::errs() << "]\n";
}
}
#endif
}
// Computes a localised ALST: only ancestors of the candidate (plus the // Computes a localised ALST: only ancestors of the candidate (plus the
// candidate itself) get recomputed, every other task keeps its current ALST. // candidate itself) get recomputed, every other task keeps its current ALST.
// Processes nodes in reverse dependency order using a pending-children // Processes nodes in reverse dependency order using a pending-children
@@ -906,32 +1177,6 @@ GraphDCP::FindSlot GraphDCP::findSlotWithFixedFinalTime(
// Candidate selection and processor assignment // Candidate selection and processor assignment
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Lowest slack wins; earliest AEST breaks ties. Critical-path tasks (zero
// slack) naturally float to the front.
TaskDCP* GraphDCP::findCandidate(const std::vector<TaskDCP*>& readyNodes) {
auto findBestNode = [](auto lft, auto rgt) {
Time leftSlack = slackOrZero((*lft)->getAest(), (*lft)->getAlst());
Time rightSlack = slackOrZero((*rgt)->getAest(), (*rgt)->getAlst());
if (leftSlack < rightSlack)
return lft;
if (rightSlack < leftSlack)
return rgt;
if ((*lft)->getAest() < (*rgt)->getAest())
return lft;
return rgt;
};
assert(!readyNodes.empty() && "expected at least one ready node");
auto validNode = readyNodes.begin();
auto bestNode = validNode;
while (validNode != readyNodes.end()) {
bestNode = findBestNode(validNode, bestNode);
std::advance(validNode, 1);
}
return *bestNode;
}
// Picks the best CPU + slot for `candidate`: // Picks the best CPU + slot for `candidate`:
// * Phase 1 (parallel, read-only): call findSlot on every candidate CPU. // * Phase 1 (parallel, read-only): call findSlot on every candidate CPU.
// * Phase 2 (sequential): process CPUs in ascending slot.aest order. For // * Phase 2 (sequential): process CPUs in ascending slot.aest order. For
@@ -940,7 +1185,7 @@ TaskDCP* GraphDCP::findCandidate(const std::vector<TaskDCP*>& readyNodes) {
// evaluate a slot for the smallest-slack child, then roll back. // evaluate a slot for the smallest-slack child, then roll back.
// * Rescue (sequential): if nothing fit, grow the CPU count if allowed, // * Rescue (sequential): if nothing fit, grow the CPU count if allowed,
// otherwise pick the CPU that leads to the smallest DCPL increase. // otherwise pick the CPU that leads to the smallest DCPL increase.
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) { GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate); CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate);
relations.descendantsTopoOrder.reserve(relations.descendants.size()); relations.descendantsTopoOrder.reserve(relations.descendants.size());
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) { for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
@@ -960,22 +1205,43 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate); const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate);
const bool candidateHasCrossbar = candidateFootprint != 0; const bool candidateHasCrossbar = candidateFootprint != 0;
const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0; const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0;
DCP_DEBUG_IF(auto dedupStart = std::chrono::steady_clock::now();)
CpuAestCache cpuAests = computeCpuAestCache(candidate);
DCP_DEBUG_IF(const bool checkCpuAestCache = std::getenv("DCP_CHECK_CPU_AEST_CACHE") != nullptr;)
llvm::SmallDenseSet<uint64_t, 32> seenProcessorKeys;
seenProcessorKeys.reserve(static_cast<size_t>(topCpu + 1));
for (CPU c = 0; c <= topCpu; c++) { for (CPU c = 0; c <= topCpu; c++) {
if (candidateHasCrossbar && c != getLastCpu()) { if (candidateHasCrossbar && c != getLastCpu()) {
CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint); CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint);
if (nextUsage >= cpuCapacity) if (nextUsage >= cpuCapacity)
continue; continue;
} }
Time candidateAest = cpuAests.get(c);
DCP_DEBUG_IF(if (checkCpuAestCache) {
Time recomputedAest = computeAestOnCpu(candidate, c);
if (candidateAest != recomputedAest) {
std::fprintf(stderr,
"[DCP_CHECK_CPU_AEST_CACHE] mismatch candidate=%zu cpu=%d cached=%llu recomputed=%llu\n",
getNodeIndex(candidate),
c,
static_cast<unsigned long long>(candidateAest),
static_cast<unsigned long long>(recomputedAest));
llvm::report_fatal_error("DCP CPU AEST cache mismatch");
}
})
if (!seenProcessorKeys.insert(computeCpuCandidateKey(candidateAest, c)).second)
continue;
processors.push_back(c); processors.push_back(c);
} }
DCP_DEBUG_IF(gSelectTimers.dedup +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - dedupStart).count();)
if (processors.empty()) { if (processors.empty()) {
CPU bestCpu = canCreateNewCpu ? getLastCpu() : 0; // processors.empty() implies !canCreateNewCpu: a fresh CPU always passes
FindSlot bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())}; // the crossbar filter and would have been added. Reaching here means every
if (canCreateNewCpu) // existing CPU is crossbar-exhausted and the task requires crossbar
incrementLastCpu(); // capacity — the placement is impossible.
insertTaskInCPU(bestCpu, candidate, bestSlot.index); llvm::report_fatal_error("DCP scheduler: crossbar capacity exhausted on all CPUs; "
return; "cannot schedule task that requires crossbar allocation");
} }
// Phase 1: parallel findSlot sweep (read-only over graph state). // Phase 1: parallel findSlot sweep (read-only over graph state).
@@ -1008,14 +1274,13 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
static bool reported = false; static bool reported = false;
if (!reported) { if (!reported) {
reported = true; reported = true;
std::fprintf(stderr, std::fprintf(
stderr,
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n", "[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
(void*) context, (void*) context,
context != nullptr ? (int) context->isMultithreadingEnabled() : -1, context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
processors.size(), processors.size(),
context != nullptr && context->isMultithreadingEnabled() context != nullptr && context->isMultithreadingEnabled() ? context->getThreadPool().getMaxConcurrency() : 0u);
? context->getThreadPool().getMaxConcurrency()
: 0u);
} }
} }
#endif #endif
@@ -1056,9 +1321,10 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();) DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();)
Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu); Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu);
Time candidateCompletion = addOrMax(slot.aest, candidateWeight); Time candidateCompletion = addOrMax(slot.aest, candidateWeight);
bool skip = (!emptyCpu && candidateCompletion > currentDcpl) bool skip =
|| addOrMax(slot.aest, candidateCompletion) >= bestComposite; (!emptyCpu && candidateCompletion > currentDcpl) || addOrMax(slot.aest, candidateCompletion) >= bestComposite;
DCP_DEBUG_IF(gSelectTimers.precheck += std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();) DCP_DEBUG_IF(gSelectTimers.precheck +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
if (skip) if (skip)
continue; continue;
DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;) DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;)
@@ -1074,8 +1340,8 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
scheduleSnapshot = dcp_graph::captureLocalScheduleState( scheduleSnapshot = dcp_graph::captureLocalScheduleState(
candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask); candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index); taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index);
bool withinBudget = tryUpdateAestWithinBudget( bool withinBudget =
candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl); tryUpdateAestWithinBudget(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
if (!withinBudget) { if (!withinBudget) {
DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();) DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();)
taskInsertion.rollBack(); taskInsertion.rollBack();
@@ -1151,7 +1417,9 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
else { else {
Time bestDcpl = std::numeric_limits<Time>::max(); Time bestDcpl = std::numeric_limits<Time>::max();
Time currentDcpl = getDcpl(); Time currentDcpl = getDcpl();
for (CPU c = 0; c < getLastCpu(); c++) { for (CPU c : processors) {
if (c == getLastCpu())
continue;
auto slot = findSlot(candidate, c, false, relations); auto slot = findSlot(candidate, c, false, relations);
if (slot.aest == std::numeric_limits<Time>::max()) if (slot.aest == std::numeric_limits<Time>::max())
slot = findSlot(candidate, c, true, relations); slot = findSlot(candidate, c, true, relations);
@@ -1160,8 +1428,7 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
// Cheap lower bound: post-insertion DCPL is at least max(currentDcpl, // Cheap lower bound: post-insertion DCPL is at least max(currentDcpl,
// candidate completion on this slot). Skip CPUs already worse than // candidate completion on this slot). Skip CPUs already worse than
// the best seen. // the best seen.
Time lowerBound = Time lowerBound = std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
if (lowerBound >= bestDcpl) if (lowerBound >= bestDcpl)
continue; continue;
auto snapshot = dcp_graph::captureLocalScheduleState( auto snapshot = dcp_graph::captureLocalScheduleState(
@@ -1170,23 +1437,37 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder)); updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder));
Time candidateDcpl = getDcpl(); Time candidateDcpl = getDcpl();
taskInsertion.rollBack(); taskInsertion.rollBack();
dcp_graph::restoreLocalScheduleState( dcp_graph::restoreLocalScheduleState(snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
if (candidateDcpl < bestDcpl) { if (candidateDcpl < bestDcpl) {
bestDcpl = candidateDcpl; bestDcpl = candidateDcpl;
bestCpu = c; bestCpu = c;
bestSlot = slot; bestSlot = slot;
} }
} }
if (bestCpu == -1) { if (bestCpu == -1)
bestCpu = 0; llvm::report_fatal_error("DCP scheduler: no valid slot found for task on any eligible CPU — "
bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())}; "all slots are blocked by already-placed descendants");
}
} }
} }
if (bestCpu == getLastCpu() && getLastCpu() < maxCpuCount) if (bestCpu == getLastCpu() && getLastCpu() < maxCpuCount)
incrementLastCpu(); incrementLastCpu();
insertTaskInCPU(bestCpu, candidate, bestSlot.index); insertTaskInCPU(bestCpu, candidate, bestSlot.index);
// Incremental AEST/ALST refresh replacing the full initAest/initAlst that
// used to run after every placement. Post-insertion relations pick up any
// new scheduling-edge ancestors/descendants introduced by the insertion.
Time oldDcpl = getDcpl();
CandidateRelations postRelations = dcp_graph::computeCandidateRelations(candidate);
llvm::SmallVector<TaskDCP*, 32> postDescendantsTopoOrder;
postDescendantsTopoOrder.reserve(postRelations.descendants.size());
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
TaskDCP* current = &*it;
if (current != candidate && postRelations.descendants.contains(current))
postDescendantsTopoOrder.push_back(current);
}
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(postDescendantsTopoOrder));
updateAlstFromScheduledTask(candidate, postRelations, oldDcpl);
return postRelations;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -1195,61 +1476,102 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
void GraphDCP::runDcp() { void GraphDCP::runDcp() {
initTopological(); initTopological();
initTaskStructureHashes();
initAest(); initAest();
initAlst(); initAlst();
dumpDot(); dumpDot();
dcp_graph::DcpProgressLogger progressLogger(nodes.size()); dcp_graph::DcpProgressLogger progressLogger(nodes.size());
llvm::DenseMap<TaskDCP*, int> unscheduledParents; llvm::DenseMap<TaskDCP*, int> unscheduledParents;
std::vector<TaskDCP*> readyNodes;
readyNodes.reserve(nodes.size()); // Min-heap over ready tasks: tightest slack first, earliest AEST as tiebreak.
// Lazy deletion: when AEST/ALST change after a placement, fresh entries are
// pushed for the affected tasks. Stale ones are detected on pop by comparing
// stored vs current (slack, aest) and re-pushed with the current values.
struct ReadyEntry {
Time slack;
Time aest;
int64_t orderKey;
TaskDCP* task;
bool operator>(const ReadyEntry& other) const {
if (slack != other.slack)
return slack > other.slack;
if (aest != other.aest)
return aest > other.aest;
return orderKey > other.orderKey;
}
};
std::priority_queue<ReadyEntry, std::vector<ReadyEntry>, std::greater<ReadyEntry>> readyQueue;
size_t readyCount = 0;
auto pushReady = [&](TaskDCP* node) {
readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node->Id(), node});
};
for (auto& node : nodes) { for (auto& node : nodes) {
int dependencyParents = dcp_graph::countDependencyParents(&node); int dependencyParents = dcp_graph::countDependencyParents(&node);
unscheduledParents[&node] = dependencyParents; unscheduledParents[&node] = dependencyParents;
if (dependencyParents == 0) if (dependencyParents == 0) {
readyNodes.push_back(&node); pushReady(&node);
++readyCount;
} }
progressLogger.printStart(readyNodes.size()); }
size_t xbarsCapacity = static_cast<size_t>(maxCpuCount) * onnx_mlir::crossbarCountInCore.getValue();
progressLogger.printStart(readyCount, maxCpuCount, xbarsCapacity);
while (!readyNodes.empty()) { while (readyCount > 0) {
DCP_DEBUG_IF(auto findStart = std::chrono::steady_clock::now();) // Pop with lazy deletion: skip stale entries and re-push with current values.
TaskDCP* candidate = findCandidate(readyNodes); TaskDCP* candidate = nullptr;
DCP_DEBUG_IF(progressLogger.recordFindDuration( while (!readyQueue.empty()) {
std::chrono::duration<double>(std::chrono::steady_clock::now() - findStart).count());) auto entry = readyQueue.top();
fastRemove(readyNodes, candidate); readyQueue.pop();
Time curSlack = slackOrZero(entry.task->getAest(), entry.task->getAlst());
Time curAest = entry.task->getAest();
if (entry.slack == curSlack && entry.aest == curAest) {
candidate = entry.task;
break;
}
readyQueue.push({curSlack, curAest, entry.orderKey, entry.task});
}
assert(candidate != nullptr && "readyCount > 0 but heap exhausted");
--readyCount;
DCP_DEBUG_IF(auto selectStart = std::chrono::steady_clock::now();) DCP_DEBUG_IF(auto selectStart = std::chrono::steady_clock::now();)
selectProcessor(candidate, candidate->isCriticalPath()); CandidateRelations postRelations = selectProcessor(candidate, candidate->isCriticalPath());
DCP_DEBUG_IF( DCP_DEBUG_IF(
double selectSeconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - selectStart).count(); double selectSeconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - selectStart).count();
progressLogger.recordSelectDuration(selectSeconds); progressLogger.recordSelectDuration(selectSeconds);
progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyNodes.size(), getLastCpu()); progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyCount, getLastCpu());)
)
// Proactively refresh the heap priority for ready nodes whose AEST or ALST
// changed: ancestors had their ALST individually recomputed; descendants had
// their AEST bumped. Both may now sort differently than their stale entries.
for (TaskDCP* node : postRelations.ancestors)
if (!node->isScheduled() && unscheduledParents[node] == 0)
pushReady(node);
for (TaskDCP* node : postRelations.descendants)
if (!node->isScheduled() && unscheduledParents[node] == 0)
pushReady(node);
DCP_DEBUG_IF(auto updateStart = std::chrono::steady_clock::now();)
initAest();
initAlst();
DCP_DEBUG_IF(progressLogger.recordUpdateDuration(
std::chrono::duration<double>(std::chrono::steady_clock::now() - updateStart).count());)
progressLogger.advanceCompleted(); progressLogger.advanceCompleted();
progressLogger.printProgress(readyNodes.size(), getLastCpu(), "recompute", false); progressLogger.printProgress(readyCount, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), false);
for (const auto& childEdge : candidate->children) { for (const auto& childEdge : candidate->children) {
if (childEdge.isScheduling || childEdge.first->isScheduled()) if (childEdge.isScheduling || childEdge.first->isScheduled())
continue; continue;
int& dependencyParents = unscheduledParents[childEdge.first]; int& dependencyParents = unscheduledParents[childEdge.first];
assert(dependencyParents > 0 && "dependency parent count must stay positive"); assert(dependencyParents > 0 && "dependency parent count must stay positive");
dependencyParents--; --dependencyParents;
if (dependencyParents == 0) if (dependencyParents == 0) {
readyNodes.push_back(childEdge.first); pushReady(childEdge.first);
++readyCount;
} }
DCP_DEBUG_IF( }
++gSelectTimers.tasksProcessed; DCP_DEBUG_IF(++gSelectTimers.tasksProcessed;
if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0)) if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
gSelectTimers.dump("tick"); gSelectTimers.dump("tick");)
)
} }
progressLogger.printProgress(readyNodes.size(), getLastCpu(), "done", true); progressLogger.printProgress(0, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), true);
dumpDot(); dumpDot();
} }
@@ -1260,8 +1582,11 @@ DCPAnalysisResult GraphDCP::getResult() {
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size()); auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
ret.dominanceOrderCompute.reserve(dominanceOrder.size()); ret.dominanceOrderCompute.reserve(dominanceOrder.size());
for (auto elem : dominanceOrder) for (auto elem : dominanceOrder) {
ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute()); auto spatCompute = elem->getSpatCompute();
if (spatCompute)
ret.dominanceOrderCompute.push_back({spatCompute.getOperation(), 0});
}
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) { for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
const CpuTaskList* tasks = findCpuTasks(cpu); const CpuTaskList* tasks = findCpuTasks(cpu);
@@ -1269,10 +1594,14 @@ DCPAnalysisResult GraphDCP::getResult() {
continue; continue;
size_t i = 0; size_t i = 0;
for (auto node : *tasks) { for (auto node : *tasks) {
ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu; auto spatCompute = node->getSpatCompute();
if (!spatCompute)
continue;
ComputeInstance instance {spatCompute.getOperation(), 0};
ret.computeToCpuMap[instance] = cpu;
if (i++ == tasks->size() - 1) { if (i++ == tasks->size() - 1) {
ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute()); ret.isLastComputeOfCpu.insert(instance);
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute(); ret.cpuToLastComputeMap[cpu] = instance;
} }
} }
} }

View File

@@ -4,6 +4,7 @@
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <list> #include <list>
#include <optional> #include <optional>
#include <unordered_map> #include <unordered_map>
@@ -48,8 +49,10 @@ private:
std::vector<TaskDCP> nodes; std::vector<TaskDCP> nodes;
onnx_mlir::LabeledList<TaskDCP> topologicalOrder; onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
std::vector<uint64_t> taskStructureHashes;
std::vector<CpuTaskList> cpuTasks; std::vector<CpuTaskList> cpuTasks;
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage; std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
llvm::DenseMap<CPU, uint64_t> cpuStructureHashes;
CPU lastCpu = 0; CPU lastCpu = 0;
long long flag = 1; long long flag = 1;
Time dcpl = 0; Time dcpl = 0;
@@ -70,6 +73,7 @@ private:
void initAest(); void initAest();
void initAlst(); void initAlst();
void initTaskStructureHashes();
Time computeAestOnCpu(TaskDCP* task, CPU cpu); Time computeAestOnCpu(TaskDCP* task, CPU cpu);
Time computeDcplOnCpu(TaskDCP* task, CPU cpu); Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
@@ -83,9 +87,15 @@ private:
// `dcplBudget`, signalling that the new DCPL would exceed the budget. // `dcplBudget`, signalling that the new DCPL would exceed the budget.
// Returns true iff the full propagation completed without exceeding the // Returns true iff the full propagation completed without exceeding the
// budget. Uses the caller's snapshot to restore AEST on the aborted tail. // budget. Uses the caller's snapshot to restore AEST on the aborted tail.
bool tryUpdateAestWithinBudget(TaskDCP* task, bool tryUpdateAestWithinBudget(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder, Time dcplBudget);
llvm::ArrayRef<TaskDCP*> descendantsTopoOrder,
Time dcplBudget); // Incrementally refreshes ALST after `task` has been scheduled. Nodes
// outside the backward cone (`relations.ancestors` plus `task`) retain
// their relative distance to the sink boundary and only absorb the signed
// DCPL delta (`newDcpl - oldDcpl`). `task` itself and its ancestors are
// recomputed in reverse topological order so that new same-CPU transfer
// costs (now zero) and scheduling-edge children are reflected.
void updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl);
void initTopological(); void initTopological();
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr); void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
@@ -94,8 +104,11 @@ private:
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations); llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
size_t getNodeIndex(const TaskDCP* task) const; size_t getNodeIndex(const TaskDCP* task) const;
TaskDCP* findCandidate(const std::vector<TaskDCP*>& readyNodes); // Returns a compact dedup key for CPU `c` when evaluating `candidate`:
void selectProcessor(TaskDCP* candidate, bool push); // mixes candidateAest, crossbar usage, and the incremental cpu structure
// hash into a single uint64_t. Zero heap allocation.
uint64_t computeCpuCandidateKey(Time candidateAest, CPU cpu);
CandidateRelations selectProcessor(TaskDCP* candidate, bool push);
CPU getLastCpu() const { return lastCpu; } CPU getLastCpu() const { return lastCpu; }
void incrementLastCpu() { lastCpu++; } void incrementLastCpu() { lastCpu++; }
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations); FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
@@ -115,24 +128,28 @@ private:
public: public:
void runDcp(); void runDcp();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes, GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes, llvm::ArrayRef<IndexedEdge> edges)
llvm::ArrayRef<IndexedEdge> edges)
: nodes(), cpuTasks(), cpuCrossbarUsage() { : nodes(), cpuTasks(), cpuCrossbarUsage() {
for (auto spatWeightedCompute : spatWeightedComputes) for (auto spatCompute : spatComputes)
nodes.emplace_back(spatWeightedCompute); nodes.emplace_back(spatCompute);
for (auto [start, end, weight] : edges) for (auto [start, end, weight] : edges)
makeEdge(start, end, weight); makeEdge(start, end, weight);
} }
GraphDCP(llvm::ArrayRef<Weight> nodeWeights, GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
llvm::ArrayRef<IndexedEdge> edges, llvm::ArrayRef<IndexedEdge> edges,
llvm::ArrayRef<int64_t> nodeOrderKeys = {},
llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {}) llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
: nodes(), cpuTasks(), cpuCrossbarUsage() { : nodes(), cpuTasks(), cpuCrossbarUsage() {
assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size()) assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
&& "synthetic crossbar usage must match synthetic node weights"); && "synthetic crossbar usage must match synthetic node weights");
assert((nodeOrderKeys.empty() || nodeOrderKeys.size() == nodeWeights.size())
&& "synthetic node order keys must match synthetic node weights");
nodes.reserve(nodeWeights.size()); nodes.reserve(nodeWeights.size());
for (auto [index, weight] : llvm::enumerate(nodeWeights)) for (auto [index, weight] : llvm::enumerate(nodeWeights))
nodes.emplace_back(index, weight, nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]); nodes.emplace_back(nodeOrderKeys.empty() ? static_cast<int64_t>(index) : nodeOrderKeys[index],
weight,
nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
for (auto [start, end, weight] : edges) for (auto [start, end, weight] : edges)
makeEdge(start, end, weight); makeEdge(start, end, weight);
} }
@@ -150,6 +167,11 @@ public:
void setMaxCpuCount(int value) { maxCpuCount = value; } void setMaxCpuCount(int value) { maxCpuCount = value; }
int getMaxCpuCount() const { return maxCpuCount; } int getMaxCpuCount() const { return maxCpuCount; }
// Total crossbar units allocated across all active CPUs.
size_t crossbarsUsed() const;
// Maximum crossbar units available across all active CPUs (lastCpu * per-CPU capacity).
size_t crossbarsAvailable() const;
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If // Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
// null the scheduler runs single-threaded (tests use this path). // null the scheduler runs single-threaded (tests use this path).
void setContext(mlir::MLIRContext* ctx) { context = ctx; } void setContext(mlir::MLIRContext* ctx) { context = ctx; }

View File

@@ -35,10 +35,12 @@ void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSe
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; } void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; } void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
void DcpProgressLogger::printStart(size_t readyCount) const { void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
if (!logProgress) if (!logProgress)
return; return;
llvm::errs() << llvm::formatv("[DCP] start: tasks={0} ready={1}\n", totalTasks, readyCount); llvm::errs() << llvm::formatv(
"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
totalTasks, readyCount, maxCpuCount, xbarsCapacity);
} }
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex, void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
@@ -48,14 +50,15 @@ void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
if (!logProgress || elapsedSeconds < 1.0) if (!logProgress || elapsedSeconds < 1.0)
return; return;
llvm::errs() << llvm::formatv("[DCP] slow candidate node={0} elapsed={1} ready={2} cpus={3}\n", llvm::errs() << llvm::formatv("[DCP] slow node={0} elapsed={1} ready={2} cpus={3}\n",
nodeIndex, nodeIndex,
formatDuration(elapsedSeconds), formatDuration(elapsedSeconds),
readyCount, readyCount,
cpuCount); cpuCount);
} }
void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force) { void DcpProgressLogger::printProgress(
size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force) {
if (!logProgress) if (!logProgress)
return; return;
@@ -68,19 +71,19 @@ void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::Str
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0; double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks); double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F1}%) ready={3} cpus={4} stage={5} elapsed={6} eta={7}\n", bool done = completedTasks == totalTasks;
llvm::errs() << llvm::formatv(
"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
completedTasks, completedTasks,
totalTasks, totalTasks,
percent, percent,
readyCount, readyCount,
cpuCount, cpuCount,
stage, maxCpuCount,
formatDuration(elapsedSeconds), xbarsUsed,
completedTasks == totalTasks ? "0:00" : formatDuration(etaSeconds)); xbarsAvailable,
llvm::errs() << llvm::formatv(" time(find={0}, select={1}, update={2})\n", llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
formatDuration(findCandidateSeconds), done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
formatDuration(selectProcessorSeconds),
formatDuration(updateTimingSeconds));
lastProgressPrint = now; lastProgressPrint = now;
} }
@@ -91,9 +94,9 @@ void DcpProgressLogger::recordFindDuration(double) {}
void DcpProgressLogger::recordSelectDuration(double) {} void DcpProgressLogger::recordSelectDuration(double) {}
void DcpProgressLogger::recordUpdateDuration(double) {} void DcpProgressLogger::recordUpdateDuration(double) {}
void DcpProgressLogger::advanceCompleted(size_t) {} void DcpProgressLogger::advanceCompleted(size_t) {}
void DcpProgressLogger::printStart(size_t) const {} void DcpProgressLogger::printStart(size_t, int, size_t) const {}
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {} void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
void DcpProgressLogger::printProgress(size_t, CPU, llvm::StringRef, bool) {} void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
#endif #endif

View File

@@ -31,9 +31,10 @@ public:
void recordUpdateDuration(double seconds); void recordUpdateDuration(double seconds);
void advanceCompleted(size_t taskCount = 1); void advanceCompleted(size_t taskCount = 1);
void printStart(size_t readyCount) const; void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const; void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
void printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force); void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
size_t xbarsUsed, size_t xbarsAvailable, bool force);
#ifdef DCP_DEBUG_ENABLED #ifdef DCP_DEBUG_ENABLED
private: private:

View File

@@ -8,7 +8,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> { class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute; onnx_mlir::spatial::SpatCompute spatCompute;
Time aest; Time aest;
Time alst; Time alst;
std::optional<CPU> scheduledCpu; std::optional<CPU> scheduledCpu;
@@ -38,22 +38,22 @@ public:
std::vector<Edge> parents; std::vector<Edge> parents;
std::vector<Edge> children; std::vector<Edge> children;
TaskDCP() = default; TaskDCP() = default;
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
: onnx_mlir::LabeledListNode<TaskDCP>(), : onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(spatWeightedCompute), spatCompute(spatCompute),
aest(0), aest(0),
alst(0), alst(0),
scheduledCpu(), scheduledCpu(),
weight(getSpatComputeWeight(spatWeightedCompute)), weight(getSpatComputeWeight(spatCompute)),
baseWeight(weight), baseWeight(weight),
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)), crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
syntheticId(-1), syntheticId(-1),
parents(), parents(),
children() {} children() {}
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0) TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
: onnx_mlir::LabeledListNode<TaskDCP>(), : onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(), spatCompute(),
aest(0), aest(0),
alst(0), alst(0),
scheduledCpu(), scheduledCpu(),
@@ -90,14 +90,14 @@ public:
void setAlst(Time value) { alst = value; } void setAlst(Time value) { alst = value; }
bool hasDescendant(TaskDCP* child); bool hasDescendant(TaskDCP* child);
int64_t Id() const { int64_t Id() const {
if (spatWeightedCompute) if (spatCompute)
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer()); return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
return syntheticId; return syntheticId;
} }
bool isCriticalPath() const { return alst == aest; } bool isCriticalPath() const { return alst == aest; }
bool isScheduled() const { return scheduledCpu.has_value(); } bool isScheduled() const { return scheduledCpu.has_value(); }
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; } onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }
void setFlag(long long val) { flag = val; } void setFlag(long long val) { flag = val; }
long long getFlag() const { return flag; } long long getFlag() const { return flag; }

View File

@@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) {
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); } inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
constexpr Weight kOperationWeight = 100; constexpr Weight kOperationWeight = 100;
Weight numOperations = 0; Weight numOperations = 0;
for (auto& block : spatWeightedCompute.getBody()) for (auto& block : spatCompute.getBody())
for ([[maybe_unused]] auto& op : block) for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1)); numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight); return checkedMultiply(numOperations, kOperationWeight);
} }
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
CrossbarUsage crossbarUsage = 0; CrossbarUsage crossbarUsage = 0;
for (auto& region : spatWeightedCompute.getBody()) for (auto& region : spatCompute.getBody())
for (auto& inst : region) for (auto& inst : region)
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst)) if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1)); crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));

View File

@@ -31,7 +31,7 @@ struct CountInstructionPass : public PassWrapper<CountInstructionPass, Operation
unsigned totalInstructionCount = 0; unsigned totalInstructionCount = 0;
unsigned computeId = 0; unsigned computeId = 0;
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) { for (auto computeOp : func.getOps<spatial::SpatCompute>()) {
unsigned instructionCount = 0; unsigned instructionCount = 0;
instructionCount += computeOp.getBody().front().getOperations().size(); instructionCount += computeOp.getBody().front().getOperations().size();
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n"; llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";

View File

@@ -116,10 +116,9 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill"); auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
pim::PimMemCopyOp::create(rewriter, pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(), mapOp.getLoc(),
@@ -258,9 +257,18 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return failure(); return failure();
// Look through an optional pim.memcp_hd to find the source get_global.
// This occurs when the constant was staged into device memory before transposing.
pim::PimMemCopyHostToDevOp memcpHd;
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>(); auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal) {
memcpHd = transposeOp.getInput().getDefiningOp<pim::PimMemCopyHostToDevOp>();
if (!memcpHd)
return failure();
sourceGetGlobal = memcpHd.getHostSource().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal) if (!sourceGetGlobal)
return failure(); return failure();
}
auto moduleOp = transposeOp->getParentOfType<ModuleOp>(); auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
if (!moduleOp) if (!moduleOp)
@@ -298,13 +306,26 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
bool isAlwaysWeight = bool isAlwaysWeight =
!transposeOp->getUsers().empty() !transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }); && llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
});
if (isAlwaysWeight) { if (isAlwaysWeight) {
markWeightAlways(newGlobal); markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal); markWeightAlways(newGetGlobal);
} }
auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp<memref::AllocOp>();
rewriter.replaceOp(transposeOp, newGetGlobal.getResult()); rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
if (memcpHd && memcpHd.use_empty()) {
auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp<memref::AllocOp>();
rewriter.eraseOp(memcpHd);
if (deviceAllocOp && deviceAllocOp->use_empty())
rewriter.eraseOp(deviceAllocOp);
}
if (outputAllocOp && outputAllocOp->use_empty())
rewriter.eraseOp(outputAllocOp);
return success(); return success();
} }
}; };
@@ -341,18 +362,25 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
continue; continue;
} }
if (!isa<pim::PimCoreOp>(user)) if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
return failure(); return failure();
} }
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) { if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }); return llvm::all_of(castOp->getUsers(), [](Operation* user) {
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
});
})) { })) {
allLiveUsersAreCoreOps = false; allLiveUsersAreCoreOps = false;
} }
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) { if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user); return isa<linalg::MapOp,
memref::SubViewOp,
memref::DeallocOp,
memref::CastOp,
pim::PimCoreOp,
pim::PimCoreBatchOp>(user);
})) { })) {
return failure(); return failure();
} }
@@ -389,6 +417,83 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
} }
}; };
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
if (copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
if (!allocOp)
return failure();
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
return failure();
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
if (failed(staticOffsets))
return failure();
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
if (failed(maybeFoldedAttr))
return failure();
foldedAttr = *maybeFoldedAttr;
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
continue;
if (isa<memref::DeallocOp>(user))
continue;
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
continue;
if (isa<memref::SubViewOp>(user)) {
allLiveUsersAreCores = false;
continue;
}
return failure();
}
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
if (allLiveUsersAreCores)
markWeightAlways(newGetGlobal);
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
rewriter.eraseOp(copyOp);
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success();
}
};
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> { struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@@ -443,7 +548,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
continue; continue;
if (isa<memref::DeallocOp>(user)) if (isa<memref::DeallocOp>(user))
continue; continue;
if (isa<pim::PimCoreOp>(user)) if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
continue; continue;
if (isa<memref::SubViewOp>(user)) { if (isa<memref::SubViewOp>(user)) {
allLiveUsersAreCores = false; allLiveUsersAreCores = false;
@@ -473,7 +578,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) { void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns patterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>( .add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantHostCopyPattern,
FoldConstantMemCpPattern>(
patterns.getContext()); patterns.getContext());
} }

View File

@@ -24,7 +24,26 @@ static bool isAddressOnlyHostOp(Operation* op) {
memref::CastOp, memref::CastOp,
memref::CollapseShapeOp, memref::CollapseShapeOp,
memref::ExpandShapeOp, memref::ExpandShapeOp,
spatial::SpatChannelNewOp>(op); memref::CopyOp>(op);
}
// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity.
// Used for memref.copy operands which may be non-contiguous subviews.
static bool isBaseAddressableValue(Value value) {
while (true) {
if (isa<BlockArgument>(value))
return true;
Operation* defOp = value.getDefiningOp();
if (!defOp)
return false;
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
return true;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) { value = collapse.getSrc(); continue; }
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) { value = expand.getSrc(); continue; }
return false;
}
} }
static bool isCodegenAddressableValue(Value value) { static bool isCodegenAddressableValue(Value value) {
@@ -38,6 +57,8 @@ static bool isCodegenAddressableValue(Value value) {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op)) if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1; return operandIndex == 1;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op)) if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0; return operandIndex == 0;
return false; return false;
@@ -69,6 +90,12 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
continue; continue;
} }
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
hasFailure = true;
continue;
}
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) { if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
if (failed(verifyReturnOp(returnOp))) if (failed(verifyReturnOp(returnOp)))
hasFailure = true; hasFailure = true;
@@ -92,10 +119,11 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
} }
private: private:
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) { template <typename CoreOpTy>
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
bool hasFailure = false; bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>(); auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) { if (!getGlobalOp) {
coreOp.emitOpError() << "weight #" << weightIndex coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as memref.get_global before JSON codegen"; << " must be materialized as memref.get_global before JSON codegen";
@@ -131,7 +159,8 @@ private:
return success(!hasFailure); return success(!hasFailure);
} }
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) { template <typename CoreOpTy>
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
return walkPimCoreBlock( return walkPimCoreBlock(
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) { coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
bool hasFailure = false; bool hasFailure = false;
@@ -174,6 +203,13 @@ private:
return verifyAddressOnlySource(op, collapseOp.getSrc()); return verifyAddressOnlySource(op, collapseOp.getSrc());
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op)) if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
return verifyAddressOnlySource(op, expandOp.getSrc()); return verifyAddressOnlySource(op, expandOp.getSrc());
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
op->emitOpError("depends on a value that is not backed by addressable storage");
return failure();
}
return success();
}
return success(); return success();
} }

View File

@@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
add_custom_target(pim-unittest) add_custom_target(pim-unittest)
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests") set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")

View File

@@ -457,6 +457,10 @@ int testDCPGraphDiamondDependencies() {
return 0; return 0;
} }
// crossbarSize=4, crossbarCount=2 => capacity = 4*4*2 = 32.
// Each task with crossbarUsage=1 needs footprint = 4*4 = 16, so at most 1 task
// can fit per CPU (16+16 = 32 >= capacity). The scheduler must open a fresh CPU
// for each task; all three end up on separate CPUs with their base weight.
int testDCPGraphCrossbarExhaustion() { int testDCPGraphCrossbarExhaustion() {
std::cout << "testDCPGraphCrossbarExhaustion:" << std::endl; std::cout << "testDCPGraphCrossbarExhaustion:" << std::endl;
configureDcpDotOutput(); configureDcpDotOutput();
@@ -473,37 +477,36 @@ int testDCPGraphCrossbarExhaustion() {
const std::vector<Weight> nodeWeights = {10, 10, 10}; const std::vector<Weight> nodeWeights = {10, 10, 10};
const std::vector<CrossbarUsage> nodeCrossbarUsage = {1, 1, 1}; const std::vector<CrossbarUsage> nodeCrossbarUsage = {1, 1, 1};
GraphDCP graph(nodeWeights, {}, nodeCrossbarUsage); GraphDCP graph(nodeWeights, {}, {},nodeCrossbarUsage);
graph.setMaxCpuCount(1); graph.setMaxCpuCount(3);
graph.runDcp(); graph.runDcp();
if (graph.cpuCount() != 1) { if (graph.cpuCount() != 3) {
restoreCrossbarOptions(); restoreCrossbarOptions();
std::cerr << "Expected exactly 1 CPU with maxCpuCount=1, got " << graph.cpuCount() << "\n"; std::cerr << "Expected 3 CPUs (one per task due to crossbar limit), got " << graph.cpuCount() << "\n";
dumpDcpFailureArtifacts(); dumpDcpFailureArtifacts();
return 1; return 1;
} }
auto scheduledTasks = graph.getScheduledTasks(0); int failures = 0;
if (scheduledTasks.size() != 3) { for (CPU c = 0; c < 3; c++) {
restoreCrossbarOptions(); auto scheduledTasks = graph.getScheduledTasks(c);
std::cerr << "Expected all three tasks to be scheduled on CPU 0\n"; if (scheduledTasks.size() != 1) {
printCpuSchedule(graph, 0); std::cerr << "Expected exactly 1 task on CPU " << c << ", got " << scheduledTasks.size() << "\n";
dumpDcpFailureArtifacts(); printCpuSchedule(graph, c);
return 1; failures++;
continue;
}
if (scheduledTasks[0].weight != 10) {
std::cerr << "Expected weight=10 on CPU " << c << ", got " << scheduledTasks[0].weight << "\n";
printCpuSchedule(graph, c);
failures++;
} }
if (scheduledTasks[0].weight != 10 || scheduledTasks[1].weight != std::numeric_limits<Weight>::max()
|| scheduledTasks[2].weight != std::numeric_limits<Weight>::max()) {
restoreCrossbarOptions();
std::cerr << "Unexpected effective weights under crossbar exhaustion\n";
printCpuSchedule(graph, 0);
dumpDcpFailureArtifacts();
return 1;
} }
restoreCrossbarOptions(); restoreCrossbarOptions();
return 0; if (failures) dumpDcpFailureArtifacts();
return failures;
} }
} // namespace } // namespace

View File

@@ -26,6 +26,10 @@ STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation") GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
def sanitize_output_name(name):
return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255])
@dataclass @dataclass
class ValidationResult: class ValidationResult:
passed: bool passed: bool
@@ -33,7 +37,7 @@ class ValidationResult:
class ProgressReporter: class ProgressReporter:
def __init__(self, total_models, stages_per_model=STAGE_COUNT): def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None):
self.total_models = total_models self.total_models = total_models
self.stages_per_model = stages_per_model self.stages_per_model = stages_per_model
self.total_steps = max(1, total_models * stages_per_model) self.total_steps = max(1, total_models * stages_per_model)
@@ -41,7 +45,7 @@ class ProgressReporter:
self.passed_models = 0 self.passed_models = 0
self.failed_models = 0 self.failed_models = 0
self.current_label = "" self.current_label = ""
self.enabled = True self.enabled = sys.stdout.isatty() if enabled is None else enabled
self.columns = shutil.get_terminal_size((100, 20)).columns self.columns = shutil.get_terminal_size((100, 20)).columns
self.suspended = False self.suspended = False
@@ -205,7 +209,7 @@ def build_dump_ranges(config_path, outputs_descriptor):
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
run_command( run_command(
["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir, cwd=simulator_dir,
reporter=reporter, reporter=reporter,
@@ -229,7 +233,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
all_passed = True all_passed = True
rows = [] rows = []
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
csv_name = f"output{oi}_{name}.csv" csv_name = f"output{oi}_{sanitize_output_name(name)}.csv"
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape) runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64)))) max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
passed = max_diff <= threshold passed = max_diff <= threshold