add peft scheduling
Validate Operations / validate-operations (push) Has been cancelled

better deadlock report by pim simulator
This commit is contained in:
NiccoloN
2026-05-18 12:09:27 +02:00
parent de0a2f4561
commit f1602c0550
26 changed files with 1215 additions and 113 deletions
+5 -2
View File
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--core-count=<N>` — number of cores. Required for PIM compilation.
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
merge-compute-nodes pass (default: `peft`).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
--onnx-include-dir ../onnx-mlir/include \
--core-count 1000
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
.lock()
.unwrap()
.init(executor.cpu().num_core(), args.output.clone());
executor.execute();
executor.execute()?;
dump_memory(executor, &args)?;
Ok(())
}
@@ -1,5 +1,6 @@
#![allow(unused)]
use anyhow::{Result, bail};
use std::{
collections::{HashMap, HashSet},
time::{Duration, SystemTime},
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
send_recv: SendRecv,
}
struct DeadlockInfo {
cycle: String,
states: String,
}
fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0;
let mut progress = 0;
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
}
}
pub fn execute<'b>(&'b mut self)
pub fn execute<'b>(&'b mut self) -> Result<()>
where
'a: 'b,
{
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
}
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
print_status(cores_instructions);
check_cycle(cpu, cores_instructions, send_recv);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
now = SystemTime::now();
}
}
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
}
print_status(cores_instructions);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
if cores_instructions
.iter()
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
{
bail!("Execution stalled with unfinished instructions");
}
#[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report();
Ok(())
}
pub fn cpu(&self) -> &CPU<'a> {
@@ -201,12 +228,12 @@ impl<'a> Executable<'a> {
}
}
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
#[derive(Debug, PartialEq, Eq)]
enum CoreState {
SendingTo(i32),
ReceivingFrom(i32),
Working,
SendingTo(i32, i32),
ReceivingFrom(i32, i32),
Working,
Halted,
}
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
let (this_core, target_core) = data.get_core_immcore();
if isa_recv(functor_address) {
states.insert(this_core, CoreState::ReceivingFrom(target_core));
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
} else if isa_send(functor_address) {
states.insert(this_core, CoreState::SendingTo(target_core));
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
} else {
states.insert(this_core, CoreState::Working);
}
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
for (&core_id, state) in states.iter() {
match state {
CoreState::SendingTo(target_core) => {
CoreState::SendingTo(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::ReceivingFrom(core_id) {
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
wait_for.insert(core_id, *target_core);
}
}
CoreState::ReceivingFrom(target_core) => {
CoreState::ReceivingFrom(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::SendingTo(core_id) {
if target_state != &CoreState::SendingTo(core_id, *size) {
wait_for.insert(core_id, *target_core);
}
}
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
.collect::<Vec<_>>()
.join(" -> ");
let cycle = cycle
.iter()
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source)
}
CoreState::Working => format!("core {} working", core),
CoreState::Halted => format!("core {} halted", core),
})
})
.collect::<Vec<_>>()
.join(", ");
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
// bail!("Deadlock detected: {}", cycle_msg);
break; // Stop tracing
return Some(DeadlockInfo {
cycle: cycle_msg,
states: states_msg,
});
}
// Hit a known branch that didn't result in a cycle
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
current_core = waiting_for;
}
}
None
}
fn handle_wait_sync<'a, 'b, 'c>(
+22 -3
View File
@@ -1,5 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
"pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
llvm::cl::init(4000));
llvm::cl::opt<bool>
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false));
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
void verifyExplicitPimCoreCount() {
if (!hasExplicitPimCoreCount())
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
if (coresCount.getValue() <= 0)
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
}
} // namespace onnx_mlir
+9
View File
@@ -20,8 +20,14 @@ typedef enum {
EmitPimCodegen = 3
} PimEmissionTargetType;
typedef enum {
MergeSchedulerPeft = 0,
MergeSchedulerDcp = 1,
} PimMergeSchedulerType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
// This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the
// wanted tile is generated by two separate operands of the concat. If this is
+1
View File
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt) {
verifyExplicitPimCoreCount();
if (pimOnlyCodegen) {
// Skip all the lowering passes and directly generate code for PIM.
@@ -1,4 +1,5 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h"
@@ -48,6 +49,13 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
for (Value result : user->getResults())
pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
if (initArg == value)
pendingValues.push_back(forOp.getResult(index));
}
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) {
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
+4
View File
@@ -9,6 +9,10 @@ add_pim_library(SpatialOps
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
@@ -2,64 +2,27 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <vector>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
// A scheduling identity that covers both spat.compute and scheduled shards of
// spat.compute_batch.
struct ComputeInstance {
mlir::Operation* op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance& other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
struct DCPAnalysisResult {
std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
};
#include "../Scheduling/MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
using DCPAnalysisResult = MergeScheduleResult;
struct DCPAnalysis {
private:
DCPAnalysisResult result;
mlir::Operation* entryOp;
mlir::Operation *entryOp;
DCPAnalysisResult run();
public:
DCPAnalysis(mlir::Operation* op)
DCPAnalysis(mlir::Operation *op)
: entryOp(op) {
result = run();
}
DCPAnalysisResult& getResult() { return result; }
DCPAnalysisResult &getResult() { return result; }
};
} // namespace spatial
} // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<ComputeInstance> {
static ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
};
} // namespace llvm
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
@@ -36,12 +36,13 @@
#include <utility>
#include <vector>
#include "DCPGraph/DCPAnalysis.hpp"
#include "RegularOpCompaction.hpp"
#include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -344,11 +345,6 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
++opIts[groupIndex];
}
llvm::stable_sort(entries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
@@ -356,8 +352,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchReceiveEntry& entry : entries) {
(void) entry;
channelIds.push_back(nextChannelId++);
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
@@ -384,11 +379,6 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
++opIts[groupIndex];
}
llvm::stable_sort(entries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
@@ -396,8 +386,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchSendEntry& entry : entries) {
(void) entry;
channelIds.push_back(nextChannelId++);
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
@@ -983,7 +972,7 @@ public:
func::FuncOp func = getOperation();
Location loc = func.getLoc();
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
spatial::MergeScheduleResult& analysisResult = getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
DenseSet<Operation*> toEraseSet;
for (ComputeInstance instance : analysisResult.dominanceOrderCompute)
toEraseSet.insert(instance.op);
@@ -994,6 +983,7 @@ public:
size_t cpu = 0;
size_t slot = 0;
size_t order = 0;
size_t executionOrder = 0;
};
struct ChannelInfo {
int64_t channelId = -1;
@@ -1117,6 +1107,10 @@ public:
return lhs.slot < rhs.slot;
return lhs.order < rhs.order;
});
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
task.executionOrder = executionOrder;
taskByKey[task.key].executionOrder = executionOrder;
}
}
std::function<bool(Operation*)> isInternalInputOp = [&](Operation* op) {
@@ -1196,7 +1190,8 @@ public:
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0});
perResultChannels[producerRef->resultIndex].push_back(
{info, task.key, inputIndex, task.executionOrder, 0});
}
continue;
}
@@ -1271,6 +1266,18 @@ public:
}
}
for (const auto& taskSends : remoteSendsByTask) {
for (const auto& sendInfos : taskSends.second) {
for (const RemoteSendInfo& sendInfo : sendInfos) {
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
}
}
}
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
for (auto& taskSends : remoteSendsByTask) {
for (const auto& sendInfos : taskSends.second) {
@@ -1601,6 +1608,7 @@ public:
for (Operation* op : orderedUsersToMove)
op->moveBefore(returnOp);
orderBilateralChannelOps(func);
rebatchEquivalentComputes(func, nextChannelId);
compactScalarChannelRuns(func, nextChannelId);
compactBatchChannelRuns(func);
@@ -1632,7 +1640,7 @@ public:
private:
std::pair<SpatCompute, ComputeValueResults>
createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const DCPAnalysisResult& analysisResult) {
createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const spatial::MergeScheduleResult& analysisResult) {
func::FuncOp func = getOperation();
auto loc = func.getLoc();
IRRewriter rewriter(&getContext());
@@ -1712,7 +1720,7 @@ private:
uint32_t firstLane,
uint32_t laneCount,
size_t currentCpu,
const DCPAnalysisResult& analysisResult,
const spatial::MergeScheduleResult& analysisResult,
std::optional<uint64_t> rebatchPhase = std::nullopt) {
func::FuncOp func = getOperation();
auto loc = func.getLoc();
@@ -11,6 +11,7 @@
#include "llvm/ADT/SmallVector.h"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -166,6 +167,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
}
static bool isForwardedChannelPayload(Value value, Block& block) {
Operation* op = value.getDefiningOp();
if (!op || op->getBlock() != &block)
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
}
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
RegularChunk chunk;
chunk.startOp = startOp.getOperation();
@@ -319,6 +331,76 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
} // namespace
void orderBilateralChannelOps(func::FuncOp funcOp) {
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
continue;
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
Block& block = compute.getBody().front();
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
for (Operation& op : block) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
continue;
}
Operation* firstMatchingSend = nullptr;
for (Operation* previous = receiveOp->getPrevNode(); previous; previous = previous->getPrevNode()) {
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(previous);
if (!sendOp || sendOp.getSourceCoreId() != static_cast<uint32_t>(coreId)
|| sendOp.getTargetCoreId() != receiveOp.getSourceCoreId()
|| !isForwardedChannelPayload(sendOp.getInput(), block)) {
continue;
}
firstMatchingSend = sendOp.getOperation();
}
if (firstMatchingSend)
moves.push_back({receiveOp, firstMatchingSend});
}
for (auto [receiveOp, insertionPoint] : moves)
receiveOp->moveBefore(insertionPoint);
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
++it;
continue;
}
SmallVector<spatial::SpatChannelReceiveOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType
|| current.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
break;
}
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
SmallVector<spatial::SpatChannelReceiveOp> sorted(run);
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
});
Block::iterator insertIt = runIt;
for (auto op : sorted)
op->moveBefore(&block, insertIt);
}
it = runIt;
}
}
}
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
@@ -369,8 +451,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
@@ -451,8 +532,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput());
@@ -636,7 +716,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
}
compactRegularChunkRun(rewriter, run);
it = runIt;
it = block.begin();
}
};
@@ -6,6 +6,7 @@
namespace onnx_mlir {
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
@@ -0,0 +1,272 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <limits>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "ComputeGraph.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
namespace {
size_t getSchedulingCpuBudget() {
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return std::numeric_limits<size_t>::max();
}
Weight getComputeBodyWeight(Region &body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto &block : body)
for ([[maybe_unused]] auto &op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
CrossbarUsage crossbarUsage = 0;
for (auto &block : body)
for (auto &op : block)
if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
bool isUsedAsWeightOnly(Operation *producerOp) {
if (producerOp->getNumResults() == 0)
return false;
for (Value result : producerOp->getResults()) {
if (result.use_empty())
return false;
for (Operation *user : result.getUsers()) {
if (auto compute = dyn_cast<SpatCompute>(user)) {
if (!llvm::is_contained(compute.getWeights(), result))
return false;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
if (!llvm::is_contained(batch.getWeights(), result))
return false;
continue;
}
return false;
}
}
return true;
}
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (const ComputeGraphEdge &edge : edges) {
if (edge.source == edge.target)
continue;
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
}
std::vector<ComputeGraphEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (const auto &[key, weight] : edgeWeights)
aggregatedEdges.push_back({key.first, key.second, weight});
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
if (lhs.source != rhs.source)
return lhs.source < rhs.source;
return lhs.target < rhs.target;
});
return aggregatedEdges;
}
} // namespace
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
}
namespace {
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
size_t chunkIndex = 0;
if (static_cast<size_t>(lane) < largeChunkSpan)
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
else
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
return getBatchChunkForIndex(batch, chunkIndex);
}
} // namespace
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
static_cast<CrossbarUsage>(instance.laneCount));
}
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op);
llvm::SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
inputs.push_back(batch.getInputs()[lane]);
return inputs;
}
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(spatCompute.getWeights().begin(), spatCompute.getWeights().end());
auto batch = cast<SpatComputeBatch>(instance.op);
llvm::SmallVector<Value, 4> weights;
weights.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
weights.push_back(batch.getWeights()[lane]);
return weights;
}
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return std::nullopt;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource();
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto spatCompute = dyn_cast<SpatCompute>(op))
return ComputeInstance {spatCompute.getOperation(), 0, 1};
if (auto batch = dyn_cast<SpatComputeBatch>(op))
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
return std::nullopt;
}
ComputeGraph buildComputeGraph(Operation *entryOp) {
ComputeGraph graph;
for (Region &region : entryOp->getRegions()) {
for (Block &block : region) {
for (Operation &op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (isUsedAsWeightOnly(spatCompute.getOperation()))
continue;
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
size_t index = graph.nodes.size();
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (isUsedAsWeightOnly(batch.getOperation()))
continue;
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
size_t index = graph.nodes.size();
graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
}
}
}
}
}
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
for (Value input : getComputeInstanceInputs(node.instance)) {
auto producerInstance = getComputeProducerInstance(input);
if (!producerInstance)
continue;
auto producerIt = graph.instanceToIndex.find(*producerInstance);
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
}
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
graph.successors.assign(graph.nodes.size(), {});
graph.predecessors.assign(graph.nodes.size(), {});
for (const ComputeGraphEdge &edge : graph.edges) {
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
}
return graph;
}
bool verifyAcyclic(const ComputeGraph &graph) {
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
std::queue<size_t> readyNodes;
for (size_t node = 0; node < graph.nodes.size(); ++node) {
remainingParents[node] = graph.predecessors[node].size();
if (remainingParents[node] == 0)
readyNodes.push(node);
}
size_t visited = 0;
while (!readyNodes.empty()) {
size_t node = readyNodes.front();
readyNodes.pop();
++visited;
for (const auto &[child, weight] : graph.successors[node]) {
(void) weight;
assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0)
readyNodes.push(child);
}
}
return visited == graph.nodes.size();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,53 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>
#include "../DCPGraph/Utils.hpp"
#include "ComputeInstance.hpp"
namespace onnx_mlir {
namespace spatial {
struct ComputeGraphNode {
ComputeInstance instance;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
size_t originalOrder = 0;
};
struct ComputeGraphEdge {
size_t source = 0;
size_t target = 0;
Weight transferCost = 0;
};
struct ComputeGraph {
llvm::SmallVector<ComputeGraphNode> nodes;
llvm::SmallVector<ComputeGraphEdge> edges;
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
};
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
bool verifyAcyclic(const ComputeGraph &graph);
size_t getBatchChunkTargetCount(int32_t laneCount);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
Weight getComputeInstanceWeight(const ComputeInstance &instance);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include <cstdint>
namespace onnx_mlir {
namespace spatial {
struct ComputeInstance {
mlir::Operation *op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance &other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
} // namespace spatial
} // namespace onnx_mlir
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
}
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
const onnx_mlir::spatial::ComputeInstance &rhs) {
return lhs == rhs;
}
};
} // namespace llvm
@@ -0,0 +1,62 @@
#include "llvm/ADT/SmallVector.h"
#include "DcpScheduler.hpp"
#include "../DCPGraph/Graph.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
namespace spatial {
MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context) {
llvm::SmallVector<Weight> nodeWeights;
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
llvm::SmallVector<int64_t> nodeOrderKeys;
llvm::SmallVector<IndexedEdge> edges;
nodeWeights.reserve(graph.nodes.size());
nodeCrossbarUsage.reserve(graph.nodes.size());
nodeOrderKeys.reserve(graph.nodes.size());
edges.reserve(graph.edges.size());
for (const ComputeGraphNode &node : graph.nodes) {
nodeWeights.push_back(node.weight);
nodeCrossbarUsage.push_back(node.crossbarUsage);
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
}
for (const ComputeGraphEdge &edge : graph.edges) {
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
}
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context);
graphDCP.runDcp();
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(graph.nodes.size());
for (const ComputeGraphNode &node : graph.nodes)
result.dominanceOrderCompute.push_back(node.instance);
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
if (scheduledTasks.empty())
continue;
for (const auto &[slot, task] : llvm::enumerate(scheduledTasks)) {
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
result.computeToCpuMap[instance] = cpu;
result.computeToCpuSlotMap[instance] = slot;
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
}
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
result.cpuToLastComputeMap[cpu] = lastInstance;
result.isLastComputeOfCpu.insert(lastInstance);
}
return result;
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,14 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "ComputeGraph.hpp"
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,25 @@
#pragma once
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstddef>
#include <cstdint>
#include <vector>
#include "ComputeInstance.hpp"
namespace onnx_mlir {
namespace spatial {
struct MergeScheduleResult {
std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
};
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,133 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <vector>
#include "ComputeGraph.hpp"
#include "../DCPGraph/DCPAnalysis.hpp"
#include "DcpScheduler.hpp"
#include "MergeSchedulingAnalysis.hpp"
#include "PeftScheduler.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
namespace spatial {
namespace {
MergeSchedulerKind getSchedulerKind() {
switch (pimMergeScheduler.getValue()) {
case MergeSchedulerPeft:
return MergeSchedulerKind::Peft;
case MergeSchedulerDcp:
return MergeSchedulerKind::Dcp;
}
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
if (!result.computeToCpuMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
if (!result.computeToCpuSlotMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
if (!result.computeToAestMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing start time");
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
});
CrossbarUsage usedCrossbars = 0;
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
if (scheduledTasks[slot].first != slot)
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
if (usedCrossbars > crossbarCapacity)
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
}
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
if (!result.isLastComputeOfCpu.count(expectedLast))
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
}
for (const ComputeGraphEdge &edge : graph.edges) {
const ComputeInstance source = graph.nodes[edge.source].instance;
const ComputeInstance target = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
const size_t targetCpu = result.computeToCpuMap.lookup(target);
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
if (targetStart < earliestTargetStart) {
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
graph.nodes[edge.source].originalOrder,
graph.nodes[edge.target].originalOrder)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
}
}
} // namespace
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
: entryOp(op) {
result = run();
}
MergeScheduleResult MergeSchedulingAnalysis::run() {
verifyExplicitPimCoreCount();
ComputeGraph graph = buildComputeGraph(entryOp);
if (!verifyAcyclic(graph))
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
MergeSchedulingOptions options;
options.kind = getSchedulerKind();
if (coresCount.getValue() > 0)
options.processorCount = static_cast<size_t>(coresCount.getValue());
MergeScheduleResult schedule;
if (options.kind == MergeSchedulerKind::Peft) {
schedule = runPeftScheduler(
graph,
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
else {
schedule = DCPAnalysis(entryOp).getResult();
}
if (options.kind == MergeSchedulerKind::Peft)
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
return schedule;
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,36 @@
#pragma once
#include "mlir/IR/Operation.h"
#include <cstddef>
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
enum class MergeSchedulerKind {
Dcp,
Peft,
};
struct MergeSchedulingOptions {
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
size_t processorCount = 0;
bool allowDcpFallbackForAutoCoreCount = true;
};
class MergeSchedulingAnalysis {
public:
explicit MergeSchedulingAnalysis(mlir::Operation *op);
MergeScheduleResult &getResult() { return result; }
private:
mlir::Operation *entryOp = nullptr;
MergeScheduleResult result;
MergeScheduleResult run();
};
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,264 @@
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <queue>
#include <vector>
#include "PeftScheduler.hpp"
namespace onnx_mlir {
namespace spatial {
namespace {
struct ScheduledTask {
size_t processor = std::numeric_limits<size_t>::max();
Time startTime = 0;
Time endTime = 0;
size_t slot = 0;
};
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph &graph) {
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
std::queue<size_t> readySinks;
std::vector<std::vector<size_t>> reverseLevels;
for (size_t node = 0; node < graph.nodes.size(); ++node) {
remainingSuccessors[node] = graph.successors[node].size();
if (remainingSuccessors[node] == 0)
readySinks.push(node);
}
size_t levelizedCount = 0;
while (!readySinks.empty()) {
size_t levelSize = readySinks.size();
std::vector<size_t> levelNodes;
levelNodes.reserve(levelSize);
for (size_t i = 0; i < levelSize; ++i) {
size_t node = readySinks.front();
readySinks.pop();
levelNodes.push_back(node);
++levelizedCount;
for (const auto &[pred, weight] : graph.predecessors[node]) {
(void) weight;
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
if (--remainingSuccessors[pred] == 0)
readySinks.push(pred);
}
}
reverseLevels.push_back(std::move(levelNodes));
}
if (levelizedCount != graph.nodes.size())
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
return reverseLevels;
}
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
constexpr size_t kMaxOctTableBytes = 1ull << 30;
if (nodeCount == 0 || processorCount == 0)
return;
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
size_t rowBytes = processorCount * sizeof(Time);
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
size_t totalBytes = nodeCount * rowBytes;
if (totalBytes > kMaxOctTableBytes) {
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
totalBytes / (1024 * 1024))
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
}
} // namespace
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options) {
const size_t nodeCount = graph.nodes.size();
const size_t processorCount = options.processorCount;
if (processorCount == 0)
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
verifyOctTableSize(nodeCount, processorCount);
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
std::vector<Time> oct(nodeCount * processorCount, 0);
std::vector<Time> minOctPlusComp(nodeCount, 0);
for (const std::vector<size_t> &levelNodes : reverseLevels) {
auto computeNodeOct = [&](size_t levelIndex) {
size_t task = levelNodes[levelIndex];
std::vector<Time> maxVals(processorCount, 0);
for (const auto &[succ, comm] : graph.successors[task]) {
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
for (size_t processor = 0; processor < processorCount; ++processor) {
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], graph.nodes[succ].weight);
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
maxVals[processor] = std::max(maxVals[processor], bestSucc);
}
}
Time minForPreds = std::numeric_limits<Time>::max();
for (size_t processor = 0; processor < processorCount; ++processor) {
oct[task * processorCount + processor] = maxVals[processor];
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], graph.nodes[task].weight));
}
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
};
if (options.context != nullptr)
mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
else
for (size_t i = 0; i < levelNodes.size(); ++i)
computeNodeOct(i);
}
struct RankEntry {
long double rank = 0.0L;
size_t node = 0;
size_t originalOrder = 0;
};
std::vector<RankEntry> ranks(nodeCount);
auto computeRank = [&](size_t node) {
long double rank = 0.0L;
for (size_t processor = 0; processor < processorCount; ++processor)
rank += static_cast<long double>(oct[node * processorCount + processor]);
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
};
if (options.context != nullptr)
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
else
for (size_t node = 0; node < nodeCount; ++node)
computeRank(node);
auto readyCompare = [&](size_t lhs, size_t rhs) {
const RankEntry &lhsRank = ranks[lhs];
const RankEntry &rhsRank = ranks[rhs];
if (lhsRank.rank != rhsRank.rank)
return lhsRank.rank < rhsRank.rank;
if (lhsRank.originalOrder != rhsRank.originalOrder)
return lhsRank.originalOrder > rhsRank.originalOrder;
return lhs > rhs;
};
std::vector<int> remainingParents(nodeCount, 0);
std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
for (size_t node = 0; node < nodeCount; ++node) {
remainingParents[node] = graph.predecessors[node].size();
if (remainingParents[node] == 0)
readyQueue.push(node);
}
std::vector<char> scheduled(nodeCount, false);
std::vector<Time> processorAvailable(processorCount, 0);
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
std::vector<ScheduledTask> schedules(nodeCount);
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
size_t scheduledCount = 0;
while (!readyQueue.empty()) {
size_t task = readyQueue.top();
readyQueue.pop();
if (scheduled[task])
continue;
size_t bestProcessor = std::numeric_limits<size_t>::max();
Time bestEst = 0;
Time bestEft = 0;
Time bestOeft = std::numeric_limits<Time>::max();
bool crossbarRejected = false;
for (size_t processor = 0; processor < processorCount; ++processor) {
if (graph.nodes[task].crossbarUsage != 0 &&
addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
crossbarRejected = true;
continue;
}
Time dataReady = 0;
for (const auto &[pred, comm] : graph.predecessors[task]) {
const ScheduledTask &predSchedule = schedules[pred];
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
}
Time est = std::max(processorAvailable[processor], dataReady);
Time eft = addOrMax(est, graph.nodes[task].weight);
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft) ||
(oeft == bestOeft && eft == bestEft && est < bestEst) ||
(oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
}
}
if (bestProcessor == std::numeric_limits<size_t>::max()) {
if (crossbarRejected) {
std::string message =
llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
graph.nodes[task].originalOrder,
options.crossbarCapacity)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
graph.nodes[task].originalOrder,
processorCount)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
schedules[task] = {bestProcessor, bestEst, bestEft, tasksByProcessor[bestProcessor].size()};
scheduled[task] = true;
++scheduledCount;
processorAvailable[bestProcessor] = bestEft;
processorCrossbars[bestProcessor] =
addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
tasksByProcessor[bestProcessor].push_back(task);
for (const auto &[child, weight] : graph.successors[task]) {
(void) weight;
assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0)
readyQueue.push(child);
}
}
if (scheduledCount != nodeCount)
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(nodeCount);
for (const ComputeGraphNode &node : graph.nodes)
result.dominanceOrderCompute.push_back(node.instance);
for (size_t processor = 0; processor < processorCount; ++processor) {
for (size_t task : tasksByProcessor[processor]) {
const ComputeInstance instance = graph.nodes[task].instance;
result.computeToCpuMap[instance] = processor;
result.computeToCpuSlotMap[instance] = schedules[task].slot;
result.computeToAestMap[instance] = schedules[task].startTime;
}
if (!tasksByProcessor[processor].empty()) {
const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
result.cpuToLastComputeMap[processor] = lastInstance;
result.isLastComputeOfCpu.insert(lastInstance);
}
}
return result;
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "ComputeGraph.hpp"
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
struct PeftScheduleOptions {
size_t processorCount = 0;
CrossbarUsage crossbarCapacity = 0;
mlir::MLIRContext *context = nullptr;
};
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
} // namespace spatial
} // namespace onnx_mlir
+4 -1
View File
@@ -41,7 +41,8 @@ def _format_command(cmd):
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None):
crossbar_size, crossbar_count, core_count=None, pim_merge_scheduler="peft",
cwd=None, verbose=False, reporter=None, timeout_sec=None):
# Define the arguments, with the possibility to set crossbar size and count
args = [
network_path,
@@ -51,6 +52,7 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
"--EmitPimCodegen",
f"--crossbar-size={crossbar_size}",
f"--crossbar-count={crossbar_count}",
f"--pim-merge-scheduler={pim_merge_scheduler}",
]
if core_count is not None:
args.append(f"--core-count={core_count}")
@@ -69,6 +71,7 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
cwd=cwd,
reporter=reporter,
capture_output=True,
timeout_sec=timeout_sec,
)
if reporter is None:
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
+19 -6
View File
@@ -3,6 +3,7 @@ import os
import pty
import selectors
import subprocess
import time
MAX_ERROR_OUTPUT_BYTES = 8192
@@ -16,16 +17,26 @@ def _read_chunk(fd, treat_eio_as_eof=False):
raise
def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True):
def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True, timeout_sec=None):
selector = selectors.DefaultSelector()
recent_output = bytearray()
captured_output = bytearray()
deadline = None if timeout_sec is None else time.monotonic() + timeout_sec
try:
selector.register(fd, selectors.EVENT_READ)
while selector.get_map():
for key, _ in selector.select():
select_timeout = None
if deadline is not None:
remaining = deadline - time.monotonic()
if remaining <= 0:
process.kill()
process.wait()
raise subprocess.TimeoutExpired(process.args, timeout_sec, output=bytes(captured_output))
select_timeout = min(1.0, remaining)
for key, _ in selector.select(select_timeout):
data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof)
if not data:
selector.unregister(key.fileobj)
@@ -53,7 +64,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=
return bytes(captured_output)
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False, timeout_sec=None):
if reporter is None:
if capture_output:
completed = subprocess.run(
@@ -62,9 +73,10 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=timeout_sec,
)
return completed.stdout.decode("utf-8", errors="replace")
subprocess.run(cmd, cwd=cwd, check=True)
subprocess.run(cmd, cwd=cwd, check=True, timeout=timeout_sec)
return None
stream_output = bool(getattr(reporter, "verbose", False))
@@ -74,6 +86,7 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
cwd=cwd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=timeout_sec,
)
if completed.returncode != 0:
raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
@@ -89,7 +102,7 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
output = _stream_output(process.stdout.fileno(), process, reporter)
output = _stream_output(process.stdout.fileno(), process, reporter, timeout_sec=timeout_sec)
return output.decode("utf-8", errors="replace") if capture_output else None
try:
@@ -102,5 +115,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
finally:
os.close(slave_fd)
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True, timeout_sec=timeout_sec)
return output.decode("utf-8", errors="replace") if capture_output else None
+9 -1
View File
@@ -64,7 +64,11 @@ def main():
ap.add_argument("--crossbar-size", type=int, default=64)
ap.add_argument("--crossbar-count", type=int, default=8)
ap.add_argument("--core-count", type=int, default=None,
help="Core count to pass to Raptor. If omitted, Raptor uses its default.")
help="Core count to pass to Raptor. Required for PIM validation.")
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
help="Scheduler used by the Spatial merge-compute-nodes pass.")
ap.add_argument("--command-timeout-seconds", type=float, default=60.0,
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
ap.add_argument("--clean", action="store_true",
help="Remove generated validation artifacts under each model workspace and exit.")
ap.add_argument("--verbose", action="store_true",
@@ -98,6 +102,8 @@ def main():
missing_args.append("--raptor-path")
if not a.onnx_include_dir:
missing_args.append("--onnx-include-dir")
if a.core_count is None:
missing_args.append("--core-count")
if missing_args:
ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))
@@ -117,6 +123,8 @@ def main():
result = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count,
pim_merge_scheduler=a.pim_merge_scheduler,
command_timeout_seconds=a.command_timeout_seconds,
threshold=a.threshold,
seed=a.seed,
reporter=reporter,
+24 -16
View File
@@ -142,8 +142,8 @@ class ProgressReporter:
self.rendered_width = 0
def run_command(cmd, cwd=None, reporter=None):
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter)
def run_command(cmd, cwd=None, reporter=None, timeout_sec=None):
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter, timeout_sec=timeout_sec)
def clean_workspace_artifacts(workspace_dir, model_stem):
@@ -186,21 +186,22 @@ def print_info(reporter, message):
reporter.log(f" {message}")
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None, timeout_sec=None):
stem = network_onnx_path.stem
onnx_ir_base = raptor_dir / stem
runner_base = runner_dir / stem
run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"], reporter=reporter)
run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter)
run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"],
reporter=reporter, timeout_sec=timeout_sec)
run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter, timeout_sec=timeout_sec)
network_so_path = runner_base.with_suffix(".so")
network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
return network_so_path, network_mlir_path
def build_onnx_runner(source_dir, build_dir, reporter=None):
run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter)
run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter)
def build_onnx_runner(source_dir, build_dir, reporter=None, timeout_sec=None):
run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter, timeout_sec=timeout_sec)
run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter, timeout_sec=timeout_sec)
return build_dir / "runner"
@@ -214,13 +215,14 @@ def build_dump_ranges(config_path, outputs_descriptor):
return ",".join(ranges)
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None, timeout_sec=None):
run_command(
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
"--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir,
reporter=reporter,
timeout_sec=timeout_sec,
)
@@ -267,8 +269,10 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
seed=0, reporter=None, model_index=1, model_total=1, verbose=False):
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None,
pim_merge_scheduler="peft", threshold=1e-3,
seed=0, reporter=None, model_index=1, model_total=1, verbose=False,
command_timeout_seconds=60.0):
network_onnx_path = Path(network_onnx_path).resolve()
raptor_path = Path(raptor_path).resolve()
onnx_include_dir = Path(onnx_include_dir).resolve()
@@ -292,7 +296,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
network_so_path, network_mlir_path = compile_onnx_network(
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter)
network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter,
timeout_sec=command_timeout_seconds)
print_info(reporter, f"MLIR saved to {network_mlir_path}")
print_info(reporter, f"Shared library saved to {network_so_path}")
reporter.advance()
@@ -300,7 +305,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
verbose=False)
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter,
timeout_sec=command_timeout_seconds)
print_info(reporter, f"Runner built at {runner_path}")
reporter.advance()
@@ -316,14 +322,15 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
Path.mkdir(out_dir, exist_ok=True)
run_cmd = [runner_path, *flags]
run_cmd += ["--save-csv-dir", f"{out_dir}"]
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter)
run_command(run_cmd, cwd=runner_build_dir, reporter=reporter, timeout_sec=command_timeout_seconds)
print_info(reporter, f"Reference outputs saved to {out_dir}")
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
pim_pass_timings = compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
core_count=core_count, pim_merge_scheduler=pim_merge_scheduler,
cwd=raptor_dir, verbose=verbose, reporter=reporter, timeout_sec=command_timeout_seconds)
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
reporter.advance()
@@ -334,7 +341,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
Path.mkdir(simulation_dir, exist_ok=True)
dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor)
output_bin_path = simulation_dir / "out.bin"
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter)
run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter,
timeout_sec=command_timeout_seconds)
print_info(reporter, f"Simulator output saved to {output_bin_path}")
reporter.advance()