better deadlock report by pim simulator
This commit is contained in:
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
|
|||||||
run only the codegen tail.
|
run only the codegen tail.
|
||||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||||
per-core count.
|
per-core count.
|
||||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||||
|
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||||
|
merge-compute-nodes pass (default: `peft`).
|
||||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||||
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
|
|||||||
```
|
```
|
||||||
validate.py \
|
validate.py \
|
||||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
--onnx-include-dir ../onnx-mlir/include
|
--onnx-include-dir ../onnx-mlir/include \
|
||||||
|
--core-count 1000
|
||||||
```
|
```
|
||||||
|
|
||||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
|
|||||||
.lock()
|
.lock()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.init(executor.cpu().num_core(), args.output.clone());
|
.init(executor.cpu().num_core(), args.output.clone());
|
||||||
executor.execute();
|
executor.execute()?;
|
||||||
dump_memory(executor, &args)?;
|
dump_memory(executor, &args)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
|
use anyhow::{Result, bail};
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
|
|||||||
send_recv: SendRecv,
|
send_recv: SendRecv,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DeadlockInfo {
|
||||||
|
cycle: String,
|
||||||
|
states: String,
|
||||||
|
}
|
||||||
|
|
||||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||||
let mut tot_instructions = 0;
|
let mut tot_instructions = 0;
|
||||||
let mut progress = 0;
|
let mut progress = 0;
|
||||||
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute<'b>(&'b mut self)
|
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||||
where
|
where
|
||||||
'a: 'b,
|
'a: 'b,
|
||||||
{
|
{
|
||||||
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||||
print_status(cores_instructions);
|
print_status(cores_instructions);
|
||||||
check_cycle(cpu, cores_instructions, send_recv);
|
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||||
|
bail!(
|
||||||
|
"Deadlock cycle detected: {} [{}]",
|
||||||
|
deadlock.cycle,
|
||||||
|
deadlock.states
|
||||||
|
);
|
||||||
|
}
|
||||||
now = SystemTime::now();
|
now = SystemTime::now();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
print_status(cores_instructions);
|
print_status(cores_instructions);
|
||||||
|
|
||||||
|
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||||
|
bail!(
|
||||||
|
"Deadlock cycle detected: {} [{}]",
|
||||||
|
deadlock.cycle,
|
||||||
|
deadlock.states
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if cores_instructions
|
||||||
|
.iter()
|
||||||
|
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||||
|
{
|
||||||
|
bail!("Execution stalled with unfinished instructions");
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "profile_time")]
|
#[cfg(feature = "profile_time")]
|
||||||
TRACER.lock().unwrap().report();
|
TRACER.lock().unwrap().report();
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cpu(&self) -> &CPU<'a> {
|
pub fn cpu(&self) -> &CPU<'a> {
|
||||||
@@ -201,11 +228,11 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
|
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
enum CoreState {
|
enum CoreState {
|
||||||
SendingTo(i32),
|
SendingTo(i32, i32),
|
||||||
ReceivingFrom(i32),
|
ReceivingFrom(i32, i32),
|
||||||
Working,
|
Working,
|
||||||
Halted,
|
Halted,
|
||||||
}
|
}
|
||||||
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
let (this_core, target_core) = data.get_core_immcore();
|
let (this_core, target_core) = data.get_core_immcore();
|
||||||
|
|
||||||
if isa_recv(functor_address) {
|
if isa_recv(functor_address) {
|
||||||
states.insert(this_core, CoreState::ReceivingFrom(target_core));
|
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||||
} else if isa_send(functor_address) {
|
} else if isa_send(functor_address) {
|
||||||
states.insert(this_core, CoreState::SendingTo(target_core));
|
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||||
} else {
|
} else {
|
||||||
states.insert(this_core, CoreState::Working);
|
states.insert(this_core, CoreState::Working);
|
||||||
}
|
}
|
||||||
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
|
|
||||||
for (&core_id, state) in states.iter() {
|
for (&core_id, state) in states.iter() {
|
||||||
match state {
|
match state {
|
||||||
CoreState::SendingTo(target_core) => {
|
CoreState::SendingTo(target_core, size) => {
|
||||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||||
if target_state != &CoreState::ReceivingFrom(core_id) {
|
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||||
wait_for.insert(core_id, *target_core);
|
wait_for.insert(core_id, *target_core);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CoreState::ReceivingFrom(target_core) => {
|
CoreState::ReceivingFrom(target_core, size) => {
|
||||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||||
if target_state != &CoreState::SendingTo(core_id) {
|
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||||
wait_for.insert(core_id, *target_core);
|
wait_for.insert(core_id, *target_core);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(" -> ");
|
.join(" -> ");
|
||||||
|
|
||||||
|
let cycle = cycle
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.chain(std::iter::once(waiting_for))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||||
|
let states_msg = cycle
|
||||||
|
.iter()
|
||||||
|
.filter_map(|core| {
|
||||||
|
states.get(core).map(|state| match state {
|
||||||
|
CoreState::SendingTo(target, size) => {
|
||||||
|
format!("core {} send {}B -> {}", core, size, target)
|
||||||
|
}
|
||||||
|
CoreState::ReceivingFrom(source, size) => {
|
||||||
|
format!("core {} recv {}B <- {}", core, size, source)
|
||||||
|
}
|
||||||
|
CoreState::Working => format!("core {} working", core),
|
||||||
|
CoreState::Halted => format!("core {} halted", core),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
|
||||||
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
|
return Some(DeadlockInfo {
|
||||||
// bail!("Deadlock detected: {}", cycle_msg);
|
cycle: cycle_msg,
|
||||||
break; // Stop tracing
|
states: states_msg,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hit a known branch that didn't result in a cycle
|
// Hit a known branch that didn't result in a cycle
|
||||||
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
current_core = waiting_for;
|
current_core = waiting_for;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_wait_sync<'a, 'b, 'c>(
|
fn handle_wait_sync<'a, 'b, 'c>(
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerOptions"
|
#define DEBUG_TYPE "PimCompilerOptions"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
|||||||
llvm::cl::init(EmitPimCodegen),
|
llvm::cl::init(EmitPimCodegen),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||||
|
"pim-merge-scheduler",
|
||||||
|
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||||
|
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||||
|
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||||
|
llvm::cl::init(MergeSchedulerPeft),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
pimOnlyCodegen("pim-only-codegen",
|
pimOnlyCodegen("pim-only-codegen",
|
||||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||||
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
|||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||||
|
|
||||||
llvm::cl::opt<long> coresCount("core-count",
|
llvm::cl::opt<long> coresCount("core-count",
|
||||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||||
llvm::cl::init(-1));
|
llvm::cl::init(-1));
|
||||||
|
|
||||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||||
"dcp-critical-window-size",
|
"dcp-critical-window-size",
|
||||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||||
llvm::cl::init(4000));
|
llvm::cl::init(4000));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
|
|||||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||||
|
|
||||||
|
void verifyExplicitPimCoreCount() {
|
||||||
|
if (!hasExplicitPimCoreCount())
|
||||||
|
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||||
|
if (coresCount.getValue() <= 0)
|
||||||
|
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -20,8 +20,14 @@ typedef enum {
|
|||||||
EmitPimCodegen = 3
|
EmitPimCodegen = 3
|
||||||
} PimEmissionTargetType;
|
} PimEmissionTargetType;
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
MergeSchedulerPeft = 0,
|
||||||
|
MergeSchedulerDcp = 1,
|
||||||
|
} PimMergeSchedulerType;
|
||||||
|
|
||||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||||
|
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
|
|||||||
extern llvm::cl::opt<long> coresCount;
|
extern llvm::cl::opt<long> coresCount;
|
||||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||||
|
|
||||||
|
bool hasExplicitPimCoreCount();
|
||||||
|
void verifyExplicitPimCoreCount();
|
||||||
|
|
||||||
// This option, by default set to false, will ignore an error when resolving a
|
// This option, by default set to false, will ignore an error when resolving a
|
||||||
// specific tiles of the operands of a concat. This specific case is when the
|
// specific tiles of the operands of a concat. This specific case is when the
|
||||||
// wanted tile is generated by two separate operands of the concat. If this is
|
// wanted tile is generated by two separate operands of the concat. If this is
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
PassManager& pm,
|
PassManager& pm,
|
||||||
EmissionTargetType& emissionTarget,
|
EmissionTargetType& emissionTarget,
|
||||||
std::string outputNameNoExt) {
|
std::string outputNameNoExt) {
|
||||||
|
verifyExplicitPimCoreCount();
|
||||||
|
|
||||||
if (pimOnlyCodegen) {
|
if (pimOnlyCodegen) {
|
||||||
// Skip all the lowering passes and directly generate code for PIM.
|
// Skip all the lowering passes and directly generate code for PIM.
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
@@ -48,6 +49,13 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
|||||||
for (Value result : user->getResults())
|
for (Value result : user->getResults())
|
||||||
pendingValues.push_back(result);
|
pendingValues.push_back(result);
|
||||||
|
|
||||||
|
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||||
|
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||||
|
if (initArg == value)
|
||||||
|
pendingValues.push_back(forOp.getResult(index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||||
for (OpResult result : user->getResults()) {
|
for (OpResult result : user->getResults()) {
|
||||||
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
|
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ add_pim_library(SpatialOps
|
|||||||
SpatialOpsCanonicalization.cpp
|
SpatialOpsCanonicalization.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||||
|
|||||||
@@ -2,64 +2,27 @@
|
|||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "../Scheduling/MergeSchedule.hpp"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
// A scheduling identity that covers both spat.compute and scheduled shards of
|
|
||||||
// spat.compute_batch.
|
|
||||||
struct ComputeInstance {
|
|
||||||
mlir::Operation* op = nullptr;
|
|
||||||
uint32_t laneStart = 0;
|
|
||||||
uint32_t laneCount = 1;
|
|
||||||
|
|
||||||
bool operator==(const ComputeInstance& other) const {
|
|
||||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DCPAnalysisResult {
|
|
||||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
|
||||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
|
||||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
|
||||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
|
||||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
|
||||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
using DCPAnalysisResult = MergeScheduleResult;
|
||||||
|
|
||||||
struct DCPAnalysis {
|
struct DCPAnalysis {
|
||||||
private:
|
private:
|
||||||
DCPAnalysisResult result;
|
DCPAnalysisResult result;
|
||||||
mlir::Operation* entryOp;
|
mlir::Operation *entryOp;
|
||||||
DCPAnalysisResult run();
|
DCPAnalysisResult run();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
DCPAnalysis(mlir::Operation* op)
|
DCPAnalysis(mlir::Operation *op)
|
||||||
: entryOp(op) {
|
: entryOp(op) {
|
||||||
result = run();
|
result = run();
|
||||||
}
|
}
|
||||||
DCPAnalysisResult& getResult() { return result; }
|
DCPAnalysisResult &getResult() { return result; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
namespace llvm {
|
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
|
||||||
template <>
|
|
||||||
struct DenseMapInfo<ComputeInstance> {
|
|
||||||
static ComputeInstance getEmptyKey() {
|
|
||||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
|
||||||
}
|
|
||||||
static ComputeInstance getTombstoneKey() {
|
|
||||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
|
||||||
}
|
|
||||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
|
||||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
|
||||||
};
|
|
||||||
} // namespace llvm
|
|
||||||
|
|||||||
@@ -36,12 +36,13 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "DCPGraph/DCPAnalysis.hpp"
|
|
||||||
#include "RegularOpCompaction.hpp"
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -344,11 +345,6 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
||||||
++opIts[groupIndex];
|
++opIts[groupIndex];
|
||||||
}
|
}
|
||||||
llvm::stable_sort(entries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<int32_t> targetCoreIds;
|
||||||
@@ -356,8 +352,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sourceCoreIds.reserve(group.size());
|
sourceCoreIds.reserve(group.size());
|
||||||
targetCoreIds.reserve(group.size());
|
targetCoreIds.reserve(group.size());
|
||||||
for (const BatchReceiveEntry& entry : entries) {
|
for (const BatchReceiveEntry& entry : entries) {
|
||||||
(void) entry;
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
channelIds.push_back(nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
}
|
}
|
||||||
@@ -384,11 +379,6 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
||||||
++opIts[groupIndex];
|
++opIts[groupIndex];
|
||||||
}
|
}
|
||||||
llvm::stable_sort(entries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<int32_t> targetCoreIds;
|
||||||
@@ -396,8 +386,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sourceCoreIds.reserve(group.size());
|
sourceCoreIds.reserve(group.size());
|
||||||
targetCoreIds.reserve(group.size());
|
targetCoreIds.reserve(group.size());
|
||||||
for (const BatchSendEntry& entry : entries) {
|
for (const BatchSendEntry& entry : entries) {
|
||||||
(void) entry;
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
channelIds.push_back(nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
}
|
}
|
||||||
@@ -983,7 +972,7 @@ public:
|
|||||||
|
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
Location loc = func.getLoc();
|
Location loc = func.getLoc();
|
||||||
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
spatial::MergeScheduleResult& analysisResult = getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||||
DenseSet<Operation*> toEraseSet;
|
DenseSet<Operation*> toEraseSet;
|
||||||
for (ComputeInstance instance : analysisResult.dominanceOrderCompute)
|
for (ComputeInstance instance : analysisResult.dominanceOrderCompute)
|
||||||
toEraseSet.insert(instance.op);
|
toEraseSet.insert(instance.op);
|
||||||
@@ -994,6 +983,7 @@ public:
|
|||||||
size_t cpu = 0;
|
size_t cpu = 0;
|
||||||
size_t slot = 0;
|
size_t slot = 0;
|
||||||
size_t order = 0;
|
size_t order = 0;
|
||||||
|
size_t executionOrder = 0;
|
||||||
};
|
};
|
||||||
struct ChannelInfo {
|
struct ChannelInfo {
|
||||||
int64_t channelId = -1;
|
int64_t channelId = -1;
|
||||||
@@ -1117,6 +1107,10 @@ public:
|
|||||||
return lhs.slot < rhs.slot;
|
return lhs.slot < rhs.slot;
|
||||||
return lhs.order < rhs.order;
|
return lhs.order < rhs.order;
|
||||||
});
|
});
|
||||||
|
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
|
||||||
|
task.executionOrder = executionOrder;
|
||||||
|
taskByKey[task.key].executionOrder = executionOrder;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::function<bool(Operation*)> isInternalInputOp = [&](Operation* op) {
|
std::function<bool(Operation*)> isInternalInputOp = [&](Operation* op) {
|
||||||
@@ -1196,7 +1190,8 @@ public:
|
|||||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
if (perResultChannels.empty())
|
if (perResultChannels.empty())
|
||||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
||||||
perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0});
|
perResultChannels[producerRef->resultIndex].push_back(
|
||||||
|
{info, task.key, inputIndex, task.executionOrder, 0});
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -1271,6 +1266,18 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const auto& taskSends : remoteSendsByTask) {
|
||||||
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
||||||
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
||||||
|
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
|
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
|
||||||
|
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||||
for (auto& taskSends : remoteSendsByTask) {
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
for (const auto& sendInfos : taskSends.second) {
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
@@ -1601,6 +1608,7 @@ public:
|
|||||||
for (Operation* op : orderedUsersToMove)
|
for (Operation* op : orderedUsersToMove)
|
||||||
op->moveBefore(returnOp);
|
op->moveBefore(returnOp);
|
||||||
|
|
||||||
|
orderBilateralChannelOps(func);
|
||||||
rebatchEquivalentComputes(func, nextChannelId);
|
rebatchEquivalentComputes(func, nextChannelId);
|
||||||
compactScalarChannelRuns(func, nextChannelId);
|
compactScalarChannelRuns(func, nextChannelId);
|
||||||
compactBatchChannelRuns(func);
|
compactBatchChannelRuns(func);
|
||||||
@@ -1632,7 +1640,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::pair<SpatCompute, ComputeValueResults>
|
std::pair<SpatCompute, ComputeValueResults>
|
||||||
createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const DCPAnalysisResult& analysisResult) {
|
createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const spatial::MergeScheduleResult& analysisResult) {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
auto loc = func.getLoc();
|
auto loc = func.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
@@ -1712,7 +1720,7 @@ private:
|
|||||||
uint32_t firstLane,
|
uint32_t firstLane,
|
||||||
uint32_t laneCount,
|
uint32_t laneCount,
|
||||||
size_t currentCpu,
|
size_t currentCpu,
|
||||||
const DCPAnalysisResult& analysisResult,
|
const spatial::MergeScheduleResult& analysisResult,
|
||||||
std::optional<uint64_t> rebatchPhase = std::nullopt) {
|
std::optional<uint64_t> rebatchPhase = std::nullopt) {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
auto loc = func.getLoc();
|
auto loc = func.getLoc();
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "RegularOpCompaction.hpp"
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -166,6 +167,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
|||||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isForwardedChannelPayload(Value value, Block& block) {
|
||||||
|
Operation* op = value.getDefiningOp();
|
||||||
|
if (!op || op->getBlock() != &block)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
|
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
|
||||||
|
|
||||||
|
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||||
RegularChunk chunk;
|
RegularChunk chunk;
|
||||||
chunk.startOp = startOp.getOperation();
|
chunk.startOp = startOp.getOperation();
|
||||||
@@ -319,6 +331,76 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
|
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||||
|
if (!coreIdAttr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
|
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||||
|
|
||||||
|
for (Operation& op : block) {
|
||||||
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||||
|
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
||||||
|
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* firstMatchingSend = nullptr;
|
||||||
|
for (Operation* previous = receiveOp->getPrevNode(); previous; previous = previous->getPrevNode()) {
|
||||||
|
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(previous);
|
||||||
|
if (!sendOp || sendOp.getSourceCoreId() != static_cast<uint32_t>(coreId)
|
||||||
|
|| sendOp.getTargetCoreId() != receiveOp.getSourceCoreId()
|
||||||
|
|| !isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
firstMatchingSend = sendOp.getOperation();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (firstMatchingSend)
|
||||||
|
moves.push_back({receiveOp, firstMatchingSend});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [receiveOp, insertionPoint] : moves)
|
||||||
|
receiveOp->moveBefore(insertionPoint);
|
||||||
|
|
||||||
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||||
|
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||||
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
|
auto runIt = it;
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||||
|
if (!current || current.getOutput().getType() != outputType
|
||||||
|
|| current.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
run.push_back(current);
|
||||||
|
++runIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() > 1) {
|
||||||
|
SmallVector<spatial::SpatChannelReceiveOp> sorted(run);
|
||||||
|
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||||
|
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
||||||
|
});
|
||||||
|
Block::iterator insertIt = runIt;
|
||||||
|
for (auto op : sorted)
|
||||||
|
op->moveBefore(&block, insertIt);
|
||||||
|
}
|
||||||
|
|
||||||
|
it = runIt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
|
||||||
@@ -369,8 +451,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sourceCoreIds.reserve(sortedEntries.size());
|
sourceCoreIds.reserve(sortedEntries.size());
|
||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
for (ReceiveEntry& entry : sortedEntries) {
|
for (ReceiveEntry& entry : sortedEntries) {
|
||||||
(void) entry;
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
channelIds.push_back(nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
}
|
}
|
||||||
@@ -451,8 +532,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
inputs.reserve(sortedEntries.size());
|
inputs.reserve(sortedEntries.size());
|
||||||
for (SendEntry& entry : sortedEntries) {
|
for (SendEntry& entry : sortedEntries) {
|
||||||
(void) entry;
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
channelIds.push_back(nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
inputs.push_back(entry.op.getInput());
|
inputs.push_back(entry.op.getInput());
|
||||||
@@ -636,7 +716,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
compactRegularChunkRun(rewriter, run);
|
compactRegularChunkRun(rewriter, run);
|
||||||
it = runIt;
|
it = block.begin();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
||||||
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||||
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||||
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||||
|
|||||||
@@ -0,0 +1,272 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <optional>
|
||||||
|
#include <queue>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget() {
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
return static_cast<size_t>(coresCount.getValue());
|
||||||
|
return std::numeric_limits<size_t>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
Weight getComputeBodyWeight(Region &body) {
|
||||||
|
constexpr Weight kOperationWeight = 100;
|
||||||
|
Weight numOperations = 0;
|
||||||
|
for (auto &block : body)
|
||||||
|
for ([[maybe_unused]] auto &op : block)
|
||||||
|
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||||
|
return checkedMultiply(numOperations, kOperationWeight);
|
||||||
|
}
|
||||||
|
|
||||||
|
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
for (auto &block : body)
|
||||||
|
for (auto &op : block)
|
||||||
|
if (isa<SpatVMMOp>(op))
|
||||||
|
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||||
|
return crossbarUsage;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isUsedAsWeightOnly(Operation *producerOp) {
|
||||||
|
if (producerOp->getNumResults() == 0)
|
||||||
|
return false;
|
||||||
|
for (Value result : producerOp->getResults()) {
|
||||||
|
if (result.use_empty())
|
||||||
|
return false;
|
||||||
|
for (Operation *user : result.getUsers()) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||||
|
if (!llvm::is_contained(compute.getWeights(), result))
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||||
|
if (!llvm::is_contained(batch.getWeights(), result))
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||||
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
|
for (const ComputeGraphEdge &edge : edges) {
|
||||||
|
if (edge.source == edge.target)
|
||||||
|
continue;
|
||||||
|
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||||
|
if (!inserted.second)
|
||||||
|
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||||
|
aggregatedEdges.reserve(edgeWeights.size());
|
||||||
|
for (const auto &[key, weight] : edgeWeights)
|
||||||
|
aggregatedEdges.push_back({key.first, key.second, weight});
|
||||||
|
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
|
||||||
|
if (lhs.source != rhs.source)
|
||||||
|
return lhs.source < rhs.source;
|
||||||
|
return lhs.target < rhs.target;
|
||||||
|
});
|
||||||
|
return aggregatedEdges;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||||
|
assert(laneCount > 0 && "laneCount must be positive");
|
||||||
|
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||||
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
|
|
||||||
|
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||||
|
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||||
|
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||||
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
|
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||||
|
|
||||||
|
size_t chunkIndex = 0;
|
||||||
|
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||||
|
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||||
|
else
|
||||||
|
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||||
|
return getBatchChunkForIndex(batch, chunkIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return getSpatComputeWeight(spatCompute);
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||||
|
}
|
||||||
|
|
||||||
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return getSpatComputeCrossbarUsage(spatCompute);
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
|
||||||
|
static_cast<CrossbarUsage>(instance.laneCount));
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> inputs;
|
||||||
|
inputs.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
inputs.push_back(batch.getInputs()[lane]);
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(spatCompute.getWeights().begin(), spatCompute.getWeights().end());
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> weights;
|
||||||
|
weights.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
weights.push_back(batch.getWeights()[lane]);
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
||||||
|
Operation *op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
value = extract.getSource();
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
||||||
|
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
||||||
|
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||||
|
ComputeGraph graph;
|
||||||
|
|
||||||
|
for (Region ®ion : entryOp->getRegions()) {
|
||||||
|
for (Block &block : region) {
|
||||||
|
for (Operation &op : block) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||||
|
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||||
|
continue;
|
||||||
|
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||||
|
size_t index = graph.nodes.size();
|
||||||
|
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
|
graph.instanceToIndex[instance] = index;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||||
|
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||||
|
continue;
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
|
||||||
|
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
||||||
|
size_t index = graph.nodes.size();
|
||||||
|
graph.nodes.push_back(
|
||||||
|
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
|
graph.instanceToIndex[instance] = index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||||
|
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||||
|
for (Value input : getComputeInstanceInputs(node.instance)) {
|
||||||
|
auto producerInstance = getComputeProducerInstance(input);
|
||||||
|
if (!producerInstance)
|
||||||
|
continue;
|
||||||
|
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||||
|
if (producerIt == graph.instanceToIndex.end())
|
||||||
|
continue;
|
||||||
|
rawEdges.push_back(
|
||||||
|
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
||||||
|
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
||||||
|
graph.successors.assign(graph.nodes.size(), {});
|
||||||
|
graph.predecessors.assign(graph.nodes.size(), {});
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
|
||||||
|
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool verifyAcyclic(const ComputeGraph &graph) {
|
||||||
|
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
|
||||||
|
std::queue<size_t> readyNodes;
|
||||||
|
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||||
|
remainingParents[node] = graph.predecessors[node].size();
|
||||||
|
if (remainingParents[node] == 0)
|
||||||
|
readyNodes.push(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t visited = 0;
|
||||||
|
while (!readyNodes.empty()) {
|
||||||
|
size_t node = readyNodes.front();
|
||||||
|
readyNodes.pop();
|
||||||
|
++visited;
|
||||||
|
for (const auto &[child, weight] : graph.successors[node]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||||
|
if (--remainingParents[child] == 0)
|
||||||
|
readyNodes.push(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return visited == graph.nodes.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../DCPGraph/Utils.hpp"
|
||||||
|
#include "ComputeInstance.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct ComputeGraphNode {
|
||||||
|
ComputeInstance instance;
|
||||||
|
Weight weight = 0;
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
size_t originalOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ComputeGraphEdge {
|
||||||
|
size_t source = 0;
|
||||||
|
size_t target = 0;
|
||||||
|
Weight transferCost = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ComputeGraph {
|
||||||
|
llvm::SmallVector<ComputeGraphNode> nodes;
|
||||||
|
llvm::SmallVector<ComputeGraphEdge> edges;
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||||
|
bool verifyAcyclic(const ComputeGraph &graph);
|
||||||
|
|
||||||
|
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||||
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||||
|
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||||
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
|
#include "llvm/ADT/Hashing.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct ComputeInstance {
|
||||||
|
mlir::Operation *op = nullptr;
|
||||||
|
uint32_t laneStart = 0;
|
||||||
|
uint32_t laneCount = 1;
|
||||||
|
|
||||||
|
bool operator==(const ComputeInstance &other) const {
|
||||||
|
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
template <>
|
||||||
|
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
||||||
|
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
|
||||||
|
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||||
|
}
|
||||||
|
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
|
||||||
|
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
|
||||||
|
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||||
|
}
|
||||||
|
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
||||||
|
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace llvm
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "DcpScheduler.hpp"
|
||||||
|
#include "../DCPGraph/Graph.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context) {
|
||||||
|
llvm::SmallVector<Weight> nodeWeights;
|
||||||
|
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||||
|
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||||
|
llvm::SmallVector<IndexedEdge> edges;
|
||||||
|
nodeWeights.reserve(graph.nodes.size());
|
||||||
|
nodeCrossbarUsage.reserve(graph.nodes.size());
|
||||||
|
nodeOrderKeys.reserve(graph.nodes.size());
|
||||||
|
edges.reserve(graph.edges.size());
|
||||||
|
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes) {
|
||||||
|
nodeWeights.push_back(node.weight);
|
||||||
|
nodeCrossbarUsage.push_back(node.crossbarUsage);
|
||||||
|
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
|
||||||
|
}
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
edges.push_back(
|
||||||
|
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||||
|
graphDCP.setContext(context);
|
||||||
|
graphDCP.runDcp();
|
||||||
|
|
||||||
|
MergeScheduleResult result;
|
||||||
|
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes)
|
||||||
|
result.dominanceOrderCompute.push_back(node.instance);
|
||||||
|
|
||||||
|
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||||
|
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||||
|
if (scheduledTasks.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (const auto &[slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||||
|
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
|
||||||
|
result.computeToCpuMap[instance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[instance] = slot;
|
||||||
|
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
|
||||||
|
result.cpuToLastComputeMap[cpu] = lastInstance;
|
||||||
|
result.isLastComputeOfCpu.insert(lastInstance);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ComputeInstance.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct MergeScheduleResult {
|
||||||
|
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||||
|
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||||
|
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||||
|
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
+133
@@ -0,0 +1,133 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||||
|
#include "DcpScheduler.hpp"
|
||||||
|
#include "MergeSchedulingAnalysis.hpp"
|
||||||
|
#include "PeftScheduler.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
MergeSchedulerKind getSchedulerKind() {
|
||||||
|
switch (pimMergeScheduler.getValue()) {
|
||||||
|
case MergeSchedulerPeft:
|
||||||
|
return MergeSchedulerKind::Peft;
|
||||||
|
case MergeSchedulerDcp:
|
||||||
|
return MergeSchedulerKind::Dcp;
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown merge scheduler kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
|
||||||
|
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||||
|
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||||
|
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
|
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
|
||||||
|
if (!result.computeToCpuMap.count(instance))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
|
||||||
|
if (!result.computeToCpuSlotMap.count(instance))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
|
||||||
|
if (!result.computeToAestMap.count(instance))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing start time");
|
||||||
|
|
||||||
|
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
|
||||||
|
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &entry : tasksByCpu) {
|
||||||
|
auto &scheduledTasks = entry.second;
|
||||||
|
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||||
|
if (lhs.first != rhs.first)
|
||||||
|
return lhs.first < rhs.first;
|
||||||
|
return lhs.second < rhs.second;
|
||||||
|
});
|
||||||
|
|
||||||
|
CrossbarUsage usedCrossbars = 0;
|
||||||
|
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||||
|
if (scheduledTasks[slot].first != slot)
|
||||||
|
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||||
|
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||||
|
if (usedCrossbars > crossbarCapacity)
|
||||||
|
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||||
|
}
|
||||||
|
|
||||||
|
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
|
||||||
|
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
|
||||||
|
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
|
||||||
|
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
|
||||||
|
if (!result.isLastComputeOfCpu.count(expectedLast))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
const ComputeInstance source = graph.nodes[edge.source].instance;
|
||||||
|
const ComputeInstance target = graph.nodes[edge.target].instance;
|
||||||
|
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
|
||||||
|
const size_t targetCpu = result.computeToCpuMap.lookup(target);
|
||||||
|
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
|
||||||
|
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
|
||||||
|
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
|
||||||
|
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
|
||||||
|
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||||
|
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||||
|
|
||||||
|
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
||||||
|
if (sourceCpu != targetCpu)
|
||||||
|
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||||
|
if (targetStart < earliestTargetStart) {
|
||||||
|
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||||
|
graph.nodes[edge.source].originalOrder,
|
||||||
|
graph.nodes[edge.target].originalOrder)
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
|
||||||
|
: entryOp(op) {
|
||||||
|
result = run();
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||||
|
verifyExplicitPimCoreCount();
|
||||||
|
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||||
|
if (!verifyAcyclic(graph))
|
||||||
|
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
|
||||||
|
|
||||||
|
MergeSchedulingOptions options;
|
||||||
|
options.kind = getSchedulerKind();
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||||
|
|
||||||
|
MergeScheduleResult schedule;
|
||||||
|
if (options.kind == MergeSchedulerKind::Peft) {
|
||||||
|
schedule = runPeftScheduler(
|
||||||
|
graph,
|
||||||
|
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||||
|
entryOp->getContext()});
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
schedule = DCPAnalysis(entryOp).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.kind == MergeSchedulerKind::Peft)
|
||||||
|
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
||||||
|
return schedule;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
+36
@@ -0,0 +1,36 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
#include "MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
enum class MergeSchedulerKind {
|
||||||
|
Dcp,
|
||||||
|
Peft,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MergeSchedulingOptions {
|
||||||
|
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
|
||||||
|
size_t processorCount = 0;
|
||||||
|
bool allowDcpFallbackForAutoCoreCount = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MergeSchedulingAnalysis {
|
||||||
|
public:
|
||||||
|
explicit MergeSchedulingAnalysis(mlir::Operation *op);
|
||||||
|
MergeScheduleResult &getResult() { return result; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::Operation *entryOp = nullptr;
|
||||||
|
MergeScheduleResult result;
|
||||||
|
|
||||||
|
MergeScheduleResult run();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
#include "mlir/IR/Threading.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "PeftScheduler.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ScheduledTask {
|
||||||
|
size_t processor = std::numeric_limits<size_t>::max();
|
||||||
|
Time startTime = 0;
|
||||||
|
Time endTime = 0;
|
||||||
|
size_t slot = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph &graph) {
|
||||||
|
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||||
|
std::queue<size_t> readySinks;
|
||||||
|
std::vector<std::vector<size_t>> reverseLevels;
|
||||||
|
|
||||||
|
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||||
|
remainingSuccessors[node] = graph.successors[node].size();
|
||||||
|
if (remainingSuccessors[node] == 0)
|
||||||
|
readySinks.push(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t levelizedCount = 0;
|
||||||
|
while (!readySinks.empty()) {
|
||||||
|
size_t levelSize = readySinks.size();
|
||||||
|
std::vector<size_t> levelNodes;
|
||||||
|
levelNodes.reserve(levelSize);
|
||||||
|
for (size_t i = 0; i < levelSize; ++i) {
|
||||||
|
size_t node = readySinks.front();
|
||||||
|
readySinks.pop();
|
||||||
|
levelNodes.push_back(node);
|
||||||
|
++levelizedCount;
|
||||||
|
for (const auto &[pred, weight] : graph.predecessors[node]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||||
|
if (--remainingSuccessors[pred] == 0)
|
||||||
|
readySinks.push(pred);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reverseLevels.push_back(std::move(levelNodes));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (levelizedCount != graph.nodes.size())
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
|
||||||
|
|
||||||
|
return reverseLevels;
|
||||||
|
}
|
||||||
|
|
||||||
|
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||||
|
constexpr size_t kMaxOctTableBytes = 1ull << 30;
|
||||||
|
if (nodeCount == 0 || processorCount == 0)
|
||||||
|
return;
|
||||||
|
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||||
|
size_t rowBytes = processorCount * sizeof(Time);
|
||||||
|
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||||
|
size_t totalBytes = nodeCount * rowBytes;
|
||||||
|
if (totalBytes > kMaxOctTableBytes) {
|
||||||
|
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
|
||||||
|
totalBytes / (1024 * 1024))
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options) {
|
||||||
|
const size_t nodeCount = graph.nodes.size();
|
||||||
|
const size_t processorCount = options.processorCount;
|
||||||
|
if (processorCount == 0)
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||||
|
|
||||||
|
verifyOctTableSize(nodeCount, processorCount);
|
||||||
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
|
|
||||||
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
|
|
||||||
|
for (const std::vector<size_t> &levelNodes : reverseLevels) {
|
||||||
|
auto computeNodeOct = [&](size_t levelIndex) {
|
||||||
|
size_t task = levelNodes[levelIndex];
|
||||||
|
std::vector<Time> maxVals(processorCount, 0);
|
||||||
|
|
||||||
|
for (const auto &[succ, comm] : graph.successors[task]) {
|
||||||
|
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], graph.nodes[succ].weight);
|
||||||
|
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
||||||
|
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Time minForPreds = std::numeric_limits<Time>::max();
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
oct[task * processorCount + processor] = maxVals[processor];
|
||||||
|
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], graph.nodes[task].weight));
|
||||||
|
}
|
||||||
|
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (options.context != nullptr)
|
||||||
|
mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
|
||||||
|
else
|
||||||
|
for (size_t i = 0; i < levelNodes.size(); ++i)
|
||||||
|
computeNodeOct(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RankEntry {
|
||||||
|
long double rank = 0.0L;
|
||||||
|
size_t node = 0;
|
||||||
|
size_t originalOrder = 0;
|
||||||
|
};
|
||||||
|
std::vector<RankEntry> ranks(nodeCount);
|
||||||
|
auto computeRank = [&](size_t node) {
|
||||||
|
long double rank = 0.0L;
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor)
|
||||||
|
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
||||||
|
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
||||||
|
};
|
||||||
|
if (options.context != nullptr)
|
||||||
|
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
||||||
|
else
|
||||||
|
for (size_t node = 0; node < nodeCount; ++node)
|
||||||
|
computeRank(node);
|
||||||
|
|
||||||
|
auto readyCompare = [&](size_t lhs, size_t rhs) {
|
||||||
|
const RankEntry &lhsRank = ranks[lhs];
|
||||||
|
const RankEntry &rhsRank = ranks[rhs];
|
||||||
|
if (lhsRank.rank != rhsRank.rank)
|
||||||
|
return lhsRank.rank < rhsRank.rank;
|
||||||
|
if (lhsRank.originalOrder != rhsRank.originalOrder)
|
||||||
|
return lhsRank.originalOrder > rhsRank.originalOrder;
|
||||||
|
return lhs > rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<int> remainingParents(nodeCount, 0);
|
||||||
|
std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
|
||||||
|
for (size_t node = 0; node < nodeCount; ++node) {
|
||||||
|
remainingParents[node] = graph.predecessors[node].size();
|
||||||
|
if (remainingParents[node] == 0)
|
||||||
|
readyQueue.push(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char> scheduled(nodeCount, false);
|
||||||
|
std::vector<Time> processorAvailable(processorCount, 0);
|
||||||
|
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
||||||
|
std::vector<ScheduledTask> schedules(nodeCount);
|
||||||
|
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||||
|
|
||||||
|
size_t scheduledCount = 0;
|
||||||
|
while (!readyQueue.empty()) {
|
||||||
|
size_t task = readyQueue.top();
|
||||||
|
readyQueue.pop();
|
||||||
|
if (scheduled[task])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t bestProcessor = std::numeric_limits<size_t>::max();
|
||||||
|
Time bestEst = 0;
|
||||||
|
Time bestEft = 0;
|
||||||
|
Time bestOeft = std::numeric_limits<Time>::max();
|
||||||
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
if (graph.nodes[task].crossbarUsage != 0 &&
|
||||||
|
addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
||||||
|
crossbarRejected = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Time dataReady = 0;
|
||||||
|
for (const auto &[pred, comm] : graph.predecessors[task]) {
|
||||||
|
const ScheduledTask &predSchedule = schedules[pred];
|
||||||
|
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||||
|
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||||
|
}
|
||||||
|
|
||||||
|
Time est = std::max(processorAvailable[processor], dataReady);
|
||||||
|
Time eft = addOrMax(est, graph.nodes[task].weight);
|
||||||
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
|
||||||
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft) ||
|
||||||
|
(oeft == bestOeft && eft == bestEft && est < bestEst) ||
|
||||||
|
(oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||||
|
bestProcessor = processor;
|
||||||
|
bestEst = est;
|
||||||
|
bestEft = eft;
|
||||||
|
bestOeft = oeft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bestProcessor == std::numeric_limits<size_t>::max()) {
|
||||||
|
if (crossbarRejected) {
|
||||||
|
std::string message =
|
||||||
|
llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
|
||||||
|
graph.nodes[task].originalOrder,
|
||||||
|
options.crossbarCapacity)
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
|
||||||
|
graph.nodes[task].originalOrder,
|
||||||
|
processorCount)
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
|
||||||
|
schedules[task] = {bestProcessor, bestEst, bestEft, tasksByProcessor[bestProcessor].size()};
|
||||||
|
scheduled[task] = true;
|
||||||
|
++scheduledCount;
|
||||||
|
processorAvailable[bestProcessor] = bestEft;
|
||||||
|
processorCrossbars[bestProcessor] =
|
||||||
|
addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||||
|
tasksByProcessor[bestProcessor].push_back(task);
|
||||||
|
|
||||||
|
for (const auto &[child, weight] : graph.successors[task]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||||
|
if (--remainingParents[child] == 0)
|
||||||
|
readyQueue.push(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scheduledCount != nodeCount)
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
||||||
|
|
||||||
|
MergeScheduleResult result;
|
||||||
|
result.dominanceOrderCompute.reserve(nodeCount);
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes)
|
||||||
|
result.dominanceOrderCompute.push_back(node.instance);
|
||||||
|
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
for (size_t task : tasksByProcessor[processor]) {
|
||||||
|
const ComputeInstance instance = graph.nodes[task].instance;
|
||||||
|
result.computeToCpuMap[instance] = processor;
|
||||||
|
result.computeToCpuSlotMap[instance] = schedules[task].slot;
|
||||||
|
result.computeToAestMap[instance] = schedules[task].startTime;
|
||||||
|
}
|
||||||
|
if (!tasksByProcessor[processor].empty()) {
|
||||||
|
const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
|
||||||
|
result.cpuToLastComputeMap[processor] = lastInstance;
|
||||||
|
result.isLastComputeOfCpu.insert(lastInstance);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct PeftScheduleOptions {
|
||||||
|
size_t processorCount = 0;
|
||||||
|
CrossbarUsage crossbarCapacity = 0;
|
||||||
|
mlir::MLIRContext *context = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -41,7 +41,8 @@ def _format_command(cmd):
|
|||||||
|
|
||||||
|
|
||||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||||
crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None):
|
crossbar_size, crossbar_count, core_count=None, pim_merge_scheduler="peft",
|
||||||
|
cwd=None, verbose=False, reporter=None, timeout_sec=None):
|
||||||
# Define the arguments, with the possibility to set crossbar size and count
|
# Define the arguments, with the possibility to set crossbar size and count
|
||||||
args = [
|
args = [
|
||||||
network_path,
|
network_path,
|
||||||
@@ -51,6 +52,7 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
|||||||
"--EmitPimCodegen",
|
"--EmitPimCodegen",
|
||||||
f"--crossbar-size={crossbar_size}",
|
f"--crossbar-size={crossbar_size}",
|
||||||
f"--crossbar-count={crossbar_count}",
|
f"--crossbar-count={crossbar_count}",
|
||||||
|
f"--pim-merge-scheduler={pim_merge_scheduler}",
|
||||||
]
|
]
|
||||||
if core_count is not None:
|
if core_count is not None:
|
||||||
args.append(f"--core-count={core_count}")
|
args.append(f"--core-count={core_count}")
|
||||||
@@ -69,6 +71,7 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
|||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
|
timeout_sec=timeout_sec,
|
||||||
)
|
)
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import os
|
|||||||
import pty
|
import pty
|
||||||
import selectors
|
import selectors
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
MAX_ERROR_OUTPUT_BYTES = 8192
|
MAX_ERROR_OUTPUT_BYTES = 8192
|
||||||
|
|
||||||
@@ -16,16 +17,26 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True):
|
def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True, timeout_sec=None):
|
||||||
selector = selectors.DefaultSelector()
|
selector = selectors.DefaultSelector()
|
||||||
recent_output = bytearray()
|
recent_output = bytearray()
|
||||||
captured_output = bytearray()
|
captured_output = bytearray()
|
||||||
|
deadline = None if timeout_sec is None else time.monotonic() + timeout_sec
|
||||||
|
|
||||||
try:
|
try:
|
||||||
selector.register(fd, selectors.EVENT_READ)
|
selector.register(fd, selectors.EVENT_READ)
|
||||||
|
|
||||||
while selector.get_map():
|
while selector.get_map():
|
||||||
for key, _ in selector.select():
|
select_timeout = None
|
||||||
|
if deadline is not None:
|
||||||
|
remaining = deadline - time.monotonic()
|
||||||
|
if remaining <= 0:
|
||||||
|
process.kill()
|
||||||
|
process.wait()
|
||||||
|
raise subprocess.TimeoutExpired(process.args, timeout_sec, output=bytes(captured_output))
|
||||||
|
select_timeout = min(1.0, remaining)
|
||||||
|
|
||||||
|
for key, _ in selector.select(select_timeout):
|
||||||
data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
|
data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
|
||||||
if not data:
|
if not data:
|
||||||
selector.unregister(key.fileobj)
|
selector.unregister(key.fileobj)
|
||||||
@@ -53,7 +64,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=
|
|||||||
return bytes(captured_output)
|
return bytes(captured_output)
|
||||||
|
|
||||||
|
|
||||||
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
|
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False, timeout_sec=None):
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
if capture_output:
|
if capture_output:
|
||||||
completed = subprocess.run(
|
completed = subprocess.run(
|
||||||
@@ -62,9 +73,10 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
|||||||
check=True,
|
check=True,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
|
timeout=timeout_sec,
|
||||||
)
|
)
|
||||||
return completed.stdout.decode("utf-8", errors="replace")
|
return completed.stdout.decode("utf-8", errors="replace")
|
||||||
subprocess.run(cmd, cwd=cwd, check=True)
|
subprocess.run(cmd, cwd=cwd, check=True, timeout=timeout_sec)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
stream_output = bool(getattr(reporter, "verbose", False))
|
stream_output = bool(getattr(reporter, "verbose", False))
|
||||||
@@ -74,6 +86,7 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
|||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
|
timeout=timeout_sec,
|
||||||
)
|
)
|
||||||
if completed.returncode != 0:
|
if completed.returncode != 0:
|
||||||
raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
|
raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
|
||||||
@@ -89,7 +102,7 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
|||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
)
|
)
|
||||||
assert process.stdout is not None
|
assert process.stdout is not None
|
||||||
output = _stream_output(process.stdout.fileno(), process, reporter)
|
output = _stream_output(process.stdout.fileno(), process, reporter, timeout_sec=timeout_sec)
|
||||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -102,5 +115,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
|||||||
finally:
|
finally:
|
||||||
os.close(slave_fd)
|
os.close(slave_fd)
|
||||||
|
|
||||||
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True, timeout_sec=timeout_sec)
|
||||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||||
|
|||||||
@@ -64,7 +64,11 @@ def main():
|
|||||||
ap.add_argument("--crossbar-size", type=int, default=64)
|
ap.add_argument("--crossbar-size", type=int, default=64)
|
||||||
ap.add_argument("--crossbar-count", type=int, default=8)
|
ap.add_argument("--crossbar-count", type=int, default=8)
|
||||||
ap.add_argument("--core-count", type=int, default=None,
|
ap.add_argument("--core-count", type=int, default=None,
|
||||||
help="Core count to pass to Raptor. If omitted, Raptor uses its default.")
|
help="Core count to pass to Raptor. Required for PIM validation.")
|
||||||
|
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
|
||||||
|
help="Scheduler used by the Spatial merge-compute-nodes pass.")
|
||||||
|
ap.add_argument("--command-timeout-seconds", type=float, default=60.0,
|
||||||
|
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
|
||||||
ap.add_argument("--clean", action="store_true",
|
ap.add_argument("--clean", action="store_true",
|
||||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||||
ap.add_argument("--verbose", action="store_true",
|
ap.add_argument("--verbose", action="store_true",
|
||||||
@@ -98,6 +102,8 @@ def main():
|
|||||||
missing_args.append("--raptor-path")
|
missing_args.append("--raptor-path")
|
||||||
if not a.onnx_include_dir:
|
if not a.onnx_include_dir:
|
||||||
missing_args.append("--onnx-include-dir")
|
missing_args.append("--onnx-include-dir")
|
||||||
|
if a.core_count is None:
|
||||||
|
missing_args.append("--core-count")
|
||||||
if missing_args:
|
if missing_args:
|
||||||
ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))
|
ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))
|
||||||
|
|
||||||
@@ -117,6 +123,8 @@ def main():
|
|||||||
result = validate_network(
|
result = validate_network(
|
||||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count,
|
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count,
|
||||||
|
pim_merge_scheduler=a.pim_merge_scheduler,
|
||||||
|
command_timeout_seconds=a.command_timeout_seconds,
|
||||||
threshold=a.threshold,
|
threshold=a.threshold,
|
||||||
seed=a.seed,
|
seed=a.seed,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
|
|||||||
+24
-16
@@ -142,8 +142,8 @@ class ProgressReporter:
|
|||||||
self.rendered_width = 0
|
self.rendered_width = 0
|
||||||
|
|
||||||
|
|
||||||
def run_command(cmd, cwd=None, reporter=None):
|
def run_command(cmd, cwd=None, reporter=None, timeout_sec=None):
|
||||||
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter)
|
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter, timeout_sec=timeout_sec)
|
||||||
|
|
||||||
|
|
||||||
def clean_workspace_artifacts(workspace_dir, model_stem):
|
def clean_workspace_artifacts(workspace_dir, model_stem):
|
||||||
@@ -186,21 +186,22 @@ def print_info(reporter, message):
|
|||||||
reporter.log(f" {message}")
|
reporter.log(f" {message}")
|
||||||
|
|
||||||
|
|
||||||
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
|
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None, timeout_sec=None):
|
||||||
stem = network_onnx_path.stem
|
stem = network_onnx_path.stem
|
||||||
onnx_ir_base = raptor_dir / stem
|
onnx_ir_base = raptor_dir / stem
|
||||||
runner_base = runner_dir / stem
|
runner_base = runner_dir / stem
|
||||||
run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"], reporter=reporter)
|
run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"],
|
||||||
run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter)
|
reporter=reporter, timeout_sec=timeout_sec)
|
||||||
|
run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter, timeout_sec=timeout_sec)
|
||||||
network_so_path = runner_base.with_suffix(".so")
|
network_so_path = runner_base.with_suffix(".so")
|
||||||
network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
|
network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
|
||||||
onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
|
onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
|
||||||
return network_so_path, network_mlir_path
|
return network_so_path, network_mlir_path
|
||||||
|
|
||||||
|
|
||||||
def build_onnx_runner(source_dir, build_dir, reporter=None):
|
def build_onnx_runner(source_dir, build_dir, reporter=None, timeout_sec=None):
|
||||||
run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter)
|
run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter, timeout_sec=timeout_sec)
|
||||||
run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter)
|
run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter, timeout_sec=timeout_sec)
|
||||||
return build_dir / "runner"
|
return build_dir / "runner"
|
||||||
|
|
||||||
|
|
||||||
@@ -214,13 +215,14 @@ def build_dump_ranges(config_path, outputs_descriptor):
|
|||||||
return ",".join(ranges)
|
return ",".join(ranges)
|
||||||
|
|
||||||
|
|
||||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None, timeout_sec=None):
|
||||||
run_command(
|
run_command(
|
||||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
|
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
|
||||||
"--",
|
"--",
|
||||||
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
||||||
cwd=simulator_dir,
|
cwd=simulator_dir,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
|
timeout_sec=timeout_sec,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -267,8 +269,10 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
|
|||||||
|
|
||||||
|
|
||||||
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||||
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
|
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None,
|
||||||
seed=0, reporter=None, model_index=1, model_total=1, verbose=False):
|
pim_merge_scheduler="peft", threshold=1e-3,
|
||||||
|
seed=0, reporter=None, model_index=1, model_total=1, verbose=False,
|
||||||
|
command_timeout_seconds=60.0):
|
||||||
network_onnx_path = Path(network_onnx_path).resolve()
|
network_onnx_path = Path(network_onnx_path).resolve()
|
||||||
raptor_path = Path(raptor_path).resolve()
|
raptor_path = Path(raptor_path).resolve()
|
||||||
onnx_include_dir = Path(onnx_include_dir).resolve()
|
onnx_include_dir = Path(onnx_include_dir).resolve()
|
||||||
@@ -292,7 +296,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
try:
|
try:
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||||
network_so_path, network_mlir_path = compile_onnx_network(
|
network_so_path, network_mlir_path = compile_onnx_network(
|
||||||
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter)
|
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter,
|
||||||
|
timeout_sec=command_timeout_seconds)
|
||||||
print_info(reporter, f"MLIR saved to {network_mlir_path}")
|
print_info(reporter, f"MLIR saved to {network_mlir_path}")
|
||||||
print_info(reporter, f"Shared library saved to {network_so_path}")
|
print_info(reporter, f"Shared library saved to {network_so_path}")
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
@@ -300,7 +305,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
|
||||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
|
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
|
||||||
verbose=False)
|
verbose=False)
|
||||||
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
|
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter,
|
||||||
|
timeout_sec=command_timeout_seconds)
|
||||||
print_info(reporter, f"Runner built at {runner_path}")
|
print_info(reporter, f"Runner built at {runner_path}")
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
|
||||||
@@ -316,14 +322,15 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
Path.mkdir(out_dir, exist_ok=True)
|
Path.mkdir(out_dir, exist_ok=True)
|
||||||
run_cmd = [runner_path, *flags]
|
run_cmd = [runner_path, *flags]
|
||||||
run_cmd += ["--save-csv-dir", f"{out_dir}"]
|
run_cmd += ["--save-csv-dir", f"{out_dir}"]
|
||||||
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter)
|
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter, timeout_sec=command_timeout_seconds)
|
||||||
print_info(reporter, f"Reference outputs saved to {out_dir}")
|
print_info(reporter, f"Reference outputs saved to {out_dir}")
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||||
pim_pass_timings = compile_with_raptor(
|
pim_pass_timings = compile_with_raptor(
|
||||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
|
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
|
||||||
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
|
core_count=core_count, pim_merge_scheduler=pim_merge_scheduler,
|
||||||
|
cwd=raptor_dir, verbose=verbose, reporter=reporter, timeout_sec=command_timeout_seconds)
|
||||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
|
||||||
@@ -334,7 +341,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
Path.mkdir(simulation_dir, exist_ok=True)
|
Path.mkdir(simulation_dir, exist_ok=True)
|
||||||
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
|
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
|
||||||
output_bin_path = simulation_dir / "out.bin"
|
output_bin_path = simulation_dir / "out.bin"
|
||||||
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter)
|
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter,
|
||||||
|
timeout_sec=command_timeout_seconds)
|
||||||
print_info(reporter, f"Simulator output saved to {output_bin_path}")
|
print_info(reporter, f"Simulator output saved to {output_bin_path}")
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user