diff --git a/README.md b/README.md index acae64c..b576cb4 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM: run only the codegen tail. - `--crossbar-size=` / `--crossbar-count=` — crossbar dimensions and per-core count. -- `--core-count=` — number of cores (`-1` picks the minimum). +- `--core-count=` — number of cores. Required for PIM compilation. +- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial + merge-compute-nodes pass (default: `peft`). - `--dcp-critical-window-size=` — DCP coarsening window (0 = legacy). - `--use-experimental-conv-impl` — alternative convolution lowering. - `--ignore-concat-error` — soft-fail corner case in `ConcatOp`. @@ -129,7 +131,8 @@ Per-operation validation (from `validation/`): ``` validate.py \ --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ - --onnx-include-dir ../onnx-mlir/include + --onnx-include-dir ../onnx-mlir/include \ + --core-count 1000 ``` End-to-end network validation (example: first 4 layers of YOLOv11n): diff --git a/backend-simulators/pim/pim-simulator/src/bin/pim-simulator/main.rs b/backend-simulators/pim/pim-simulator/src/bin/pim-simulator/main.rs index ad2371c..bf50c95 100644 --- a/backend-simulators/pim/pim-simulator/src/bin/pim-simulator/main.rs +++ b/backend-simulators/pim/pim-simulator/src/bin/pim-simulator/main.rs @@ -67,7 +67,7 @@ fn main() -> Result<()> { .lock() .unwrap() .init(executor.cpu().num_core(), args.output.clone()); - executor.execute(); + executor.execute()?; dump_memory(executor, &args)?; Ok(()) } diff --git a/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs b/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs index 89a9866..ddb40ee 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs @@ -1,5 +1,6 @@ #![allow(unused)] +use anyhow::{Result, bail}; use std::{ collections::{HashMap, HashSet}, time::{Duration, SystemTime}, @@ -87,6 +88,11 @@ pub struct Executable<'a> { send_recv: SendRecv, } +struct DeadlockInfo { + cycle: String, + states: String, +} + fn print_status(core_instructions: &[CoreInstructions]) { let mut tot_instructions = 0; let mut progress = 0; @@ -118,7 +124,7 @@ impl<'a> Executable<'a> { } } - pub fn execute<'b>(&'b mut self) + pub fn execute<'b>(&'b mut self) -> Result<()> where 'a: 'b, { @@ -153,7 +159,13 @@ impl<'a> Executable<'a> { } if (now.elapsed().unwrap() > Duration::from_secs(5)) { print_status(cores_instructions); - check_cycle(cpu, cores_instructions, send_recv); + if let Some(deadlock) = detect_deadlock(cores_instructions) { + bail!( + "Deadlock cycle detected: {} [{}]", + deadlock.cycle, + deadlock.states + ); + } now = SystemTime::now(); } } @@ -178,8 +190,23 @@ impl<'a> Executable<'a> { } print_status(cores_instructions); + if let Some(deadlock) = detect_deadlock(cores_instructions) { + bail!( + "Deadlock cycle detected: {} [{}]", + deadlock.cycle, + deadlock.states + ); + } + if cores_instructions + .iter() + .any(|core_inst| core_inst.program_counter < core_inst.instructions.len()) + { + bail!("Execution stalled with unfinished instructions"); + } + #[cfg(feature = "profile_time")] TRACER.lock().unwrap().report(); + Ok(()) } pub fn cpu(&self) -> &CPU<'a> { @@ -201,12 +228,12 @@ impl<'a> Executable<'a> { } } -fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) { +fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option { #[derive(Debug, PartialEq, Eq)] enum CoreState { - SendingTo(i32), - ReceivingFrom(i32), - Working, + SendingTo(i32, i32), + ReceivingFrom(i32, i32), + Working, Halted, } @@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv let (this_core, target_core) = data.get_core_immcore(); if isa_recv(functor_address) { - states.insert(this_core, CoreState::ReceivingFrom(target_core)); + states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len())); } else if isa_send(functor_address) { - states.insert(this_core, CoreState::SendingTo(target_core)); + states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len())); } else { states.insert(this_core, CoreState::Working); } @@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv for (&core_id, state) in states.iter() { match state { - CoreState::SendingTo(target_core) => { + CoreState::SendingTo(target_core, size) => { let target_state = states.get(target_core).unwrap_or(&CoreState::Halted); - if target_state != &CoreState::ReceivingFrom(core_id) { + if target_state != &CoreState::ReceivingFrom(core_id, *size) { wait_for.insert(core_id, *target_core); } } - CoreState::ReceivingFrom(target_core) => { + CoreState::ReceivingFrom(target_core, size) => { let target_state = states.get(target_core).unwrap_or(&CoreState::Halted); - if target_state != &CoreState::SendingTo(core_id) { + if target_state != &CoreState::SendingTo(core_id, *size) { wait_for.insert(core_id, *target_core); } } @@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv .collect::>() .join(" -> "); + let cycle = cycle + .iter() + .copied() + .chain(std::iter::once(waiting_for)) + .collect::>(); let cycle_msg = format!("{} -> {}", cycle_str, waiting_for); + let states_msg = cycle + .iter() + .filter_map(|core| { + states.get(core).map(|state| match state { + CoreState::SendingTo(target, size) => { + format!("core {} send {}B -> {}", core, size, target) + } + CoreState::ReceivingFrom(source, size) => { + format!("core {} recv {}B <- {}", core, size, source) + } + CoreState::Working => format!("core {} working", core), + CoreState::Halted => format!("core {} halted", core), + }) + }) + .collect::>() + .join(", "); - println!("Fatal: Deadlock cycle detected: {}", cycle_msg); - // bail!("Deadlock detected: {}", cycle_msg); - break; // Stop tracing + return Some(DeadlockInfo { + cycle: cycle_msg, + states: states_msg, + }); } // Hit a known branch that didn't result in a cycle @@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv current_core = waiting_for; } } + None } fn handle_wait_sync<'a, 'b, 'c>( diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 03fd678..37410ed 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -1,5 +1,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "llvm/Support/ErrorHandling.h" + #define DEBUG_TYPE "PimCompilerOptions" namespace onnx_mlir { @@ -13,6 +15,14 @@ llvm::cl::opt pimEmissionTarget( llvm::cl::init(EmitPimCodegen), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt pimMergeScheduler( + "pim-merge-scheduler", + llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"), + llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")), + llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")), + llvm::cl::init(MergeSchedulerPeft), + llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt pimOnlyCodegen("pim-only-codegen", llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"), @@ -30,19 +40,19 @@ llvm::cl::opt pimEmitJson("pim-emit-json", llvm::cl::cat(OnnxMlirOptions)); llvm::cl::opt - crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); + crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2)); llvm::cl::opt crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256)); llvm::cl::opt coresCount("core-count", - llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."), + llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."), llvm::cl::init(-1)); llvm::cl::opt dcpCriticalWindowSize( "dcp-critical-window-size", llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. " - "Use 0 to run the legacy full-graph DCP analysis."), + "Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."), llvm::cl::init(4000)); llvm::cl::opt @@ -50,4 +60,13 @@ llvm::cl::opt llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"), llvm::cl::init(false)); +bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; } + +void verifyExplicitPimCoreCount() { + if (!hasExplicitPimCoreCount()) + llvm::report_fatal_error("PIM compilation requires an explicit --core-count="); + if (coresCount.getValue() <= 0) + llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer"); +} + } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index 8fb1467..a5f1d16 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -20,8 +20,14 @@ typedef enum { EmitPimCodegen = 3 } PimEmissionTargetType; +typedef enum { + MergeSchedulerPeft = 0, + MergeSchedulerDcp = 1, +} PimMergeSchedulerType; + extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::opt pimEmissionTarget; +extern llvm::cl::opt pimMergeScheduler; extern llvm::cl::opt pimOnlyCodegen; extern llvm::cl::opt useExperimentalConvImpl; @@ -32,6 +38,9 @@ extern llvm::cl::opt crossbarCountInCore; extern llvm::cl::opt coresCount; extern llvm::cl::opt dcpCriticalWindowSize; +bool hasExplicitPimCoreCount(); +void verifyExplicitPimCoreCount(); + // This option, by default set to false, will ignore an error when resolving a // specific tiles of the operands of a concat. This specific case is when the // wanted tile is generated by two separate operands of the concat. If this is diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index 1e1ed1e..bc72bd6 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef& module, PassManager& pm, EmissionTargetType& emissionTarget, std::string outputNameNoExt) { + verifyExplicitPimCoreCount(); if (pimOnlyCodegen) { // Skip all the lowering passes and directly generate code for PIM. diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp index 47cb267..a1c6f0a 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp +++ b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/ADT/DenseMap.h" @@ -48,6 +49,13 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMapgetResults()) pendingValues.push_back(result); + if (auto forOp = dyn_cast(user)) { + for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) { + if (initArg == value) + pendingValues.push_back(forOp.getResult(index)); + } + } + if (auto dpsOp = dyn_cast(user)) { for (OpResult result : user->getResults()) { OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result); diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 6f6a7ff..722c0ac 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -9,6 +9,10 @@ add_pim_library(SpatialOps SpatialOpsCanonicalization.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/RegularOpCompaction.cpp + Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp + Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp + Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp + Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index defcebf..580616c 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -2,64 +2,27 @@ #include "mlir/IR/Operation.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" - -#include -#include - -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -// A scheduling identity that covers both spat.compute and scheduled shards of -// spat.compute_batch. -struct ComputeInstance { - mlir::Operation* op = nullptr; - uint32_t laneStart = 0; - uint32_t laneCount = 1; - - bool operator==(const ComputeInstance& other) const { - return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount; - } -}; - -struct DCPAnalysisResult { - std::vector dominanceOrderCompute; - llvm::DenseMap computeToCpuMap; - llvm::DenseMap computeToCpuSlotMap; - llvm::DenseMap computeToAestMap; - llvm::DenseSet isLastComputeOfCpu; - llvm::DenseMap cpuToLastComputeMap; -}; +#include "../Scheduling/MergeSchedule.hpp" namespace onnx_mlir { namespace spatial { +using DCPAnalysisResult = MergeScheduleResult; + struct DCPAnalysis { private: DCPAnalysisResult result; - mlir::Operation* entryOp; + mlir::Operation *entryOp; DCPAnalysisResult run(); public: - DCPAnalysis(mlir::Operation* op) + DCPAnalysis(mlir::Operation *op) : entryOp(op) { result = run(); } - DCPAnalysisResult& getResult() { return result; } + DCPAnalysisResult &getResult() { return result; } }; } // namespace spatial } // namespace onnx_mlir -namespace llvm { -template <> -struct DenseMapInfo { - static ComputeInstance getEmptyKey() { - return {DenseMapInfo::getEmptyKey(), UINT32_MAX, UINT32_MAX}; - } - static ComputeInstance getTombstoneKey() { - return {DenseMapInfo::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; - } - static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); } - static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; } -}; -} // namespace llvm +using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 2b90118..3ed695d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -36,12 +36,13 @@ #include #include -#include "DCPGraph/DCPAnalysis.hpp" #include "RegularOpCompaction.hpp" +#include "Scheduling/MergeSchedulingAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -344,11 +345,6 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { {groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()}); ++opIts[groupIndex]; } - llvm::stable_sort(entries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - SmallVector channelIds; SmallVector sourceCoreIds; SmallVector targetCoreIds; @@ -356,8 +352,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { sourceCoreIds.reserve(group.size()); targetCoreIds.reserve(group.size()); for (const BatchReceiveEntry& entry : entries) { - (void) entry; - channelIds.push_back(nextChannelId++); + channelIds.push_back(static_cast(entry.channelId)); sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); } @@ -384,11 +379,6 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()}); ++opIts[groupIndex]; } - llvm::stable_sort(entries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - SmallVector channelIds; SmallVector sourceCoreIds; SmallVector targetCoreIds; @@ -396,8 +386,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { sourceCoreIds.reserve(group.size()); targetCoreIds.reserve(group.size()); for (const BatchSendEntry& entry : entries) { - (void) entry; - channelIds.push_back(nextChannelId++); + channelIds.push_back(static_cast(entry.channelId)); sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); } @@ -983,7 +972,7 @@ public: func::FuncOp func = getOperation(); Location loc = func.getLoc(); - DCPAnalysisResult& analysisResult = getAnalysis().getResult(); + spatial::MergeScheduleResult& analysisResult = getAnalysis().getResult(); DenseSet toEraseSet; for (ComputeInstance instance : analysisResult.dominanceOrderCompute) toEraseSet.insert(instance.op); @@ -994,6 +983,7 @@ public: size_t cpu = 0; size_t slot = 0; size_t order = 0; + size_t executionOrder = 0; }; struct ChannelInfo { int64_t channelId = -1; @@ -1117,6 +1107,10 @@ public: return lhs.slot < rhs.slot; return lhs.order < rhs.order; }); + for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) { + task.executionOrder = executionOrder; + taskByKey[task.key].executionOrder = executionOrder; + } } std::function isInternalInputOp = [&](Operation* op) { @@ -1196,7 +1190,8 @@ public: auto& perResultChannels = remoteSendsByTask[producerRef->instance]; if (perResultChannels.empty()) perResultChannels.resize(getTaskOutputTypes(producerIt->second).size()); - perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0}); + perResultChannels[producerRef->resultIndex].push_back( + {info, task.key, inputIndex, task.executionOrder, 0}); } continue; } @@ -1271,6 +1266,18 @@ public: } } + for (const auto& taskSends : remoteSendsByTask) { + for (const auto& sendInfos : taskSends.second) { + for (const RemoteSendInfo& sendInfo : sendInfos) { + auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer); + assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send"); + assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); + assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel"); + remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo; + } + } + } + DenseMap>> receiveQueuesByCpu; for (auto& taskSends : remoteSendsByTask) { for (const auto& sendInfos : taskSends.second) { @@ -1601,6 +1608,7 @@ public: for (Operation* op : orderedUsersToMove) op->moveBefore(returnOp); + orderBilateralChannelOps(func); rebatchEquivalentComputes(func, nextChannelId); compactScalarChannelRuns(func, nextChannelId); compactBatchChannelRuns(func); @@ -1632,7 +1640,7 @@ public: private: std::pair - createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const DCPAnalysisResult& analysisResult) { + createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const spatial::MergeScheduleResult& analysisResult) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); @@ -1712,7 +1720,7 @@ private: uint32_t firstLane, uint32_t laneCount, size_t currentCpu, - const DCPAnalysisResult& analysisResult, + const spatial::MergeScheduleResult& analysisResult, std::optional rebatchPhase = std::nullopt) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index 64cdfc7..815cd54 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/SmallVector.h" #include "RegularOpCompaction.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -166,6 +167,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); } +static bool isForwardedChannelPayload(Value value, Block& block) { + Operation* op = value.getDefiningOp(); + if (!op || op->getBlock() != &block) + return true; + + if (auto extractSliceOp = dyn_cast(op)) + return isForwardedChannelPayload(extractSliceOp.getSource(), block); + + return isa(op); +} + static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { RegularChunk chunk; chunk.startOp = startOp.getOperation(); @@ -319,6 +331,76 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef } // namespace +void orderBilateralChannelOps(func::FuncOp funcOp) { + for (auto compute : funcOp.getOps()) { + auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName); + if (!coreIdAttr) + continue; + + int32_t coreId = static_cast(coreIdAttr.getInt()); + Block& block = compute.getBody().front(); + SmallVector> moves; + + for (Operation& op : block) { + auto receiveOp = dyn_cast(&op); + if (!receiveOp || receiveOp.getTargetCoreId() != static_cast(coreId) + || receiveOp.getSourceCoreId() >= static_cast(coreId)) { + continue; + } + + Operation* firstMatchingSend = nullptr; + for (Operation* previous = receiveOp->getPrevNode(); previous; previous = previous->getPrevNode()) { + auto sendOp = dyn_cast(previous); + if (!sendOp || sendOp.getSourceCoreId() != static_cast(coreId) + || sendOp.getTargetCoreId() != receiveOp.getSourceCoreId() + || !isForwardedChannelPayload(sendOp.getInput(), block)) { + continue; + } + firstMatchingSend = sendOp.getOperation(); + } + + if (firstMatchingSend) + moves.push_back({receiveOp, firstMatchingSend}); + } + + for (auto [receiveOp, insertionPoint] : moves) + receiveOp->moveBefore(insertionPoint); + + for (auto it = block.begin(); it != block.end();) { + auto receiveOp = dyn_cast(&*it); + if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast(coreId)) { + ++it; + continue; + } + + SmallVector run; + Type outputType = receiveOp.getOutput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getOutput().getType() != outputType + || current.getSourceCoreId() >= static_cast(coreId)) { + break; + } + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + SmallVector sorted(run); + llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) { + return lhs.getSourceCoreId() > rhs.getSourceCoreId(); + }); + Block::iterator insertIt = runIt; + for (auto op : sorted) + op->moveBefore(&block, insertIt); + } + + it = runIt; + } + } +} + void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { IRRewriter rewriter(funcOp.getContext()); @@ -369,8 +451,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { sourceCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size()); for (ReceiveEntry& entry : sortedEntries) { - (void) entry; - channelIds.push_back(nextChannelId++); + channelIds.push_back(static_cast(entry.channelId)); sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); } @@ -451,8 +532,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { targetCoreIds.reserve(sortedEntries.size()); inputs.reserve(sortedEntries.size()); for (SendEntry& entry : sortedEntries) { - (void) entry; - channelIds.push_back(nextChannelId++); + channelIds.push_back(static_cast(entry.channelId)); sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); inputs.push_back(entry.op.getInput()); @@ -636,7 +716,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) { } compactRegularChunkRun(rewriter, run); - it = runIt; + it = block.begin(); } }; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp index 08b7d1e..79cdf09 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp @@ -6,6 +6,7 @@ namespace onnx_mlir { +void orderBilateralChannelOps(mlir::func::FuncOp funcOp); void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId); void compactBatchChannelRuns(mlir::func::FuncOp funcOp); void compactRegularOpRuns(mlir::func::FuncOp funcOp); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp new file mode 100644 index 0000000..b40be65 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -0,0 +1,272 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" + +#include +#include +#include +#include +#include +#include + +#include "ComputeGraph.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Support/TypeUtilities.hpp" + +namespace onnx_mlir { +namespace spatial { + +using namespace mlir; + +namespace { + +size_t getSchedulingCpuBudget() { + if (coresCount.getValue() > 0) + return static_cast(coresCount.getValue()); + return std::numeric_limits::max(); +} + +Weight getComputeBodyWeight(Region &body) { + constexpr Weight kOperationWeight = 100; + Weight numOperations = 0; + for (auto &block : body) + for ([[maybe_unused]] auto &op : block) + numOperations = checkedAdd(numOperations, static_cast(1)); + return checkedMultiply(numOperations, kOperationWeight); +} + +CrossbarUsage getComputeBodyCrossbarUsage(Region &body) { + CrossbarUsage crossbarUsage = 0; + for (auto &block : body) + for (auto &op : block) + if (isa(op)) + crossbarUsage = checkedAdd(crossbarUsage, static_cast(1)); + return crossbarUsage; +} + +bool isUsedAsWeightOnly(Operation *producerOp) { + if (producerOp->getNumResults() == 0) + return false; + for (Value result : producerOp->getResults()) { + if (result.use_empty()) + return false; + for (Operation *user : result.getUsers()) { + if (auto compute = dyn_cast(user)) { + if (!llvm::is_contained(compute.getWeights(), result)) + return false; + continue; + } + if (auto batch = dyn_cast(user)) { + if (!llvm::is_contained(batch.getWeights(), result)) + return false; + continue; + } + return false; + } + } + return true; +} + +std::vector aggregateEdges(llvm::ArrayRef edges) { + llvm::DenseMap, Weight> edgeWeights; + for (const ComputeGraphEdge &edge : edges) { + if (edge.source == edge.target) + continue; + auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost); + if (!inserted.second) + inserted.first->second = std::max(inserted.first->second, edge.transferCost); + } + + std::vector aggregatedEdges; + aggregatedEdges.reserve(edgeWeights.size()); + for (const auto &[key, weight] : edgeWeights) + aggregatedEdges.push_back({key.first, key.second, weight}); + llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) { + if (lhs.source != rhs.source) + return lhs.source < rhs.source; + return lhs.target < rhs.target; + }); + return aggregatedEdges; +} + +} // namespace + +size_t getBatchChunkTargetCount(int32_t laneCount) { + assert(laneCount > 0 && "laneCount must be positive"); + return std::min(static_cast(laneCount), std::max(1, getSchedulingCpuBudget())); +} + +ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { + size_t totalLanes = static_cast(batch.getLaneCount()); + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + size_t baseChunkSize = totalLanes / chunkCount; + size_t largeChunkCount = totalLanes % chunkCount; + + size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount); + size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0); + return {batch.getOperation(), static_cast(laneStart), static_cast(laneCount)}; +} + +namespace { + +ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { + size_t totalLanes = static_cast(batch.getLaneCount()); + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + size_t baseChunkSize = totalLanes / chunkCount; + size_t largeChunkCount = totalLanes % chunkCount; + size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1); + + size_t chunkIndex = 0; + if (static_cast(lane) < largeChunkSpan) + chunkIndex = static_cast(lane) / (baseChunkSize + 1); + else + chunkIndex = largeChunkCount + (static_cast(lane) - largeChunkSpan) / baseChunkSize; + return getBatchChunkForIndex(batch, chunkIndex); +} + +} // namespace + +Weight getComputeInstanceWeight(const ComputeInstance &instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return getSpatComputeWeight(spatCompute); + auto batch = cast(instance.op); + return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast(instance.laneCount)); +} + +CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return getSpatComputeCrossbarUsage(spatCompute); + auto batch = cast(instance.op); + return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), + static_cast(instance.laneCount)); +} + +llvm::SmallVector getComputeInstanceInputs(const ComputeInstance &instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return llvm::SmallVector(spatCompute.getInputs().begin(), spatCompute.getInputs().end()); + auto batch = cast(instance.op); + llvm::SmallVector inputs; + inputs.reserve(instance.laneCount); + for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) + inputs.push_back(batch.getInputs()[lane]); + return inputs; +} + +llvm::SmallVector getComputeInstanceWeights(const ComputeInstance &instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return llvm::SmallVector(spatCompute.getWeights().begin(), spatCompute.getWeights().end()); + auto batch = cast(instance.op); + llvm::SmallVector weights; + weights.reserve(instance.laneCount); + for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) + weights.push_back(batch.getWeights()[lane]); + return weights; +} + +std::optional getComputeProducerInstance(Value value) { + Operation *op = value.getDefiningOp(); + if (!op) + return std::nullopt; + + while (auto extract = dyn_cast(op)) { + value = extract.getSource(); + op = value.getDefiningOp(); + if (!op) + return std::nullopt; + } + + if (auto spatCompute = dyn_cast(op)) + return ComputeInstance {spatCompute.getOperation(), 0, 1}; + if (auto batch = dyn_cast(op)) + return getBatchChunkForLane(batch, static_cast(cast(value).getResultNumber())); + return std::nullopt; +} + +ComputeGraph buildComputeGraph(Operation *entryOp) { + ComputeGraph graph; + + for (Region ®ion : entryOp->getRegions()) { + for (Block &block : region) { + for (Operation &op : block) { + if (auto spatCompute = dyn_cast(&op)) { + if (isUsedAsWeightOnly(spatCompute.getOperation())) + continue; + ComputeInstance instance {spatCompute.getOperation(), 0, 1}; + size_t index = graph.nodes.size(); + graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); + graph.instanceToIndex[instance] = index; + continue; + } + if (auto batch = dyn_cast(&op)) { + if (isUsedAsWeightOnly(batch.getOperation())) + continue; + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) { + ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex); + size_t index = graph.nodes.size(); + graph.nodes.push_back( + {instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); + graph.instanceToIndex[instance] = index; + } + } + } + } + } + + llvm::SmallVector rawEdges; + for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) { + for (Value input : getComputeInstanceInputs(node.instance)) { + auto producerInstance = getComputeProducerInstance(input); + if (!producerInstance) + continue; + auto producerIt = graph.instanceToIndex.find(*producerInstance); + if (producerIt == graph.instanceToIndex.end()) + continue; + rawEdges.push_back( + {producerIt->second, targetIndex, static_cast(getSizeInBytes(cast(input.getType())))}); + } + } + + std::vector aggregatedEdges = aggregateEdges(rawEdges); + graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end()); + graph.successors.assign(graph.nodes.size(), {}); + graph.predecessors.assign(graph.nodes.size(), {}); + for (const ComputeGraphEdge &edge : graph.edges) { + graph.successors[edge.source].push_back({edge.target, edge.transferCost}); + graph.predecessors[edge.target].push_back({edge.source, edge.transferCost}); + } + + return graph; +} + +bool verifyAcyclic(const ComputeGraph &graph) { + std::vector remainingParents(graph.nodes.size(), 0); + std::queue readyNodes; + for (size_t node = 0; node < graph.nodes.size(); ++node) { + remainingParents[node] = graph.predecessors[node].size(); + if (remainingParents[node] == 0) + readyNodes.push(node); + } + + size_t visited = 0; + while (!readyNodes.empty()) { + size_t node = readyNodes.front(); + readyNodes.pop(); + ++visited; + for (const auto &[child, weight] : graph.successors[node]) { + (void) weight; + assert(remainingParents[child] > 0 && "remaining parent count underflow"); + if (--remainingParents[child] == 0) + readyNodes.push(child); + } + } + + return visited == graph.nodes.size(); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp new file mode 100644 index 0000000..a6ee020 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include +#include +#include + +#include "../DCPGraph/Utils.hpp" +#include "ComputeInstance.hpp" + +namespace onnx_mlir { +namespace spatial { + +struct ComputeGraphNode { + ComputeInstance instance; + Weight weight = 0; + CrossbarUsage crossbarUsage = 0; + size_t originalOrder = 0; +}; + +struct ComputeGraphEdge { + size_t source = 0; + size_t target = 0; + Weight transferCost = 0; +}; + +struct ComputeGraph { + llvm::SmallVector nodes; + llvm::SmallVector edges; + std::vector>> successors; + std::vector>> predecessors; + llvm::DenseMap instanceToIndex; +}; + +ComputeGraph buildComputeGraph(mlir::Operation *entryOp); +bool verifyAcyclic(const ComputeGraph &graph); + +size_t getBatchChunkTargetCount(int32_t laneCount); +ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex); +std::optional getComputeProducerInstance(mlir::Value value); +llvm::SmallVector getComputeInstanceInputs(const ComputeInstance &instance); +llvm::SmallVector getComputeInstanceWeights(const ComputeInstance &instance); +Weight getComputeInstanceWeight(const ComputeInstance &instance); +CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance); + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp new file mode 100644 index 0000000..2d160d5 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "mlir/IR/Operation.h" + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" + +#include + +namespace onnx_mlir { +namespace spatial { + +struct ComputeInstance { + mlir::Operation *op = nullptr; + uint32_t laneStart = 0; + uint32_t laneCount = 1; + + bool operator==(const ComputeInstance &other) const { + return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount; + } +}; + +} // namespace spatial +} // namespace onnx_mlir + +using ComputeInstance = onnx_mlir::spatial::ComputeInstance; + +namespace llvm { +template <> +struct DenseMapInfo { + static onnx_mlir::spatial::ComputeInstance getEmptyKey() { + return {DenseMapInfo::getEmptyKey(), UINT32_MAX, UINT32_MAX}; + } + static onnx_mlir::spatial::ComputeInstance getTombstoneKey() { + return {DenseMapInfo::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; + } + static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) { + return llvm::hash_combine(value.op, value.laneStart, value.laneCount); + } + static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs, + const onnx_mlir::spatial::ComputeInstance &rhs) { + return lhs == rhs; + } +}; +} // namespace llvm diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp new file mode 100644 index 0000000..b4fac49 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp @@ -0,0 +1,62 @@ +#include "llvm/ADT/SmallVector.h" + +#include "DcpScheduler.hpp" +#include "../DCPGraph/Graph.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + +namespace onnx_mlir { +namespace spatial { + +MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context) { + llvm::SmallVector nodeWeights; + llvm::SmallVector nodeCrossbarUsage; + llvm::SmallVector nodeOrderKeys; + llvm::SmallVector edges; + nodeWeights.reserve(graph.nodes.size()); + nodeCrossbarUsage.reserve(graph.nodes.size()); + nodeOrderKeys.reserve(graph.nodes.size()); + edges.reserve(graph.edges.size()); + + for (const ComputeGraphNode &node : graph.nodes) { + nodeWeights.push_back(node.weight); + nodeCrossbarUsage.push_back(node.crossbarUsage); + nodeOrderKeys.push_back(static_cast(node.originalOrder)); + } + for (const ComputeGraphEdge &edge : graph.edges) { + edges.push_back( + {static_cast(edge.source), static_cast(edge.target), static_cast(edge.transferCost)}); + } + + GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage); + if (coresCount.getValue() > 0) + graphDCP.setMaxCpuCount(static_cast(coresCount.getValue())); + graphDCP.setContext(context); + graphDCP.runDcp(); + + MergeScheduleResult result; + result.dominanceOrderCompute.reserve(graph.nodes.size()); + for (const ComputeGraphNode &node : graph.nodes) + result.dominanceOrderCompute.push_back(node.instance); + + for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) { + auto scheduledTasks = graphDCP.getScheduledTasks(cpu); + if (scheduledTasks.empty()) + continue; + + for (const auto &[slot, task] : llvm::enumerate(scheduledTasks)) { + const ComputeInstance instance = graph.nodes[task.nodeIndex].instance; + result.computeToCpuMap[instance] = cpu; + result.computeToCpuSlotMap[instance] = slot; + result.computeToAestMap[instance] = static_cast(task.aest); + } + + const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance; + result.cpuToLastComputeMap[cpu] = lastInstance; + result.isLastComputeOfCpu.insert(lastInstance); + } + + return result; +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp new file mode 100644 index 0000000..eeeeca8 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "mlir/IR/MLIRContext.h" + +#include "ComputeGraph.hpp" +#include "MergeSchedule.hpp" + +namespace onnx_mlir { +namespace spatial { + +MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context); + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp new file mode 100644 index 0000000..b941631 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" + +#include +#include +#include + +#include "ComputeInstance.hpp" + +namespace onnx_mlir { +namespace spatial { + +struct MergeScheduleResult { + std::vector dominanceOrderCompute; + llvm::DenseMap computeToCpuMap; + llvm::DenseMap computeToCpuSlotMap; + llvm::DenseMap computeToAestMap; + llvm::DenseSet isLastComputeOfCpu; + llvm::DenseMap cpuToLastComputeMap; +}; + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp new file mode 100644 index 0000000..135a8d6 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp @@ -0,0 +1,133 @@ +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" + +#include +#include + +#include "ComputeGraph.hpp" +#include "../DCPGraph/DCPAnalysis.hpp" +#include "DcpScheduler.hpp" +#include "MergeSchedulingAnalysis.hpp" +#include "PeftScheduler.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + +namespace onnx_mlir { +namespace spatial { + +namespace { + +MergeSchedulerKind getSchedulerKind() { + switch (pimMergeScheduler.getValue()) { + case MergeSchedulerPeft: + return MergeSchedulerKind::Peft; + case MergeSchedulerDcp: + return MergeSchedulerKind::Dcp; + } + llvm_unreachable("unknown merge scheduler kind"); +} + +void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) { + llvm::DenseMap>> tasksByCpu; + tasksByCpu.reserve(result.cpuToLastComputeMap.size()); + + for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) { + const ComputeInstance instance = graph.nodes[nodeIndex].instance; + if (!result.computeToCpuMap.count(instance)) + llvm::report_fatal_error("merge scheduling: missing CPU assignment"); + if (!result.computeToCpuSlotMap.count(instance)) + llvm::report_fatal_error("merge scheduling: missing CPU slot assignment"); + if (!result.computeToAestMap.count(instance)) + llvm::report_fatal_error("merge scheduling: missing start time"); + + tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back( + {result.computeToCpuSlotMap.lookup(instance), nodeIndex}); + } + + for (auto &entry : tasksByCpu) { + auto &scheduledTasks = entry.second; + llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) { + if (lhs.first != rhs.first) + return lhs.first < rhs.first; + return lhs.second < rhs.second; + }); + + CrossbarUsage usedCrossbars = 0; + for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) { + if (scheduledTasks[slot].first != slot) + llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous"); + usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage); + if (usedCrossbars > crossbarCapacity) + llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded"); + } + + const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance; + auto lastIt = result.cpuToLastComputeMap.find(entry.first); + if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast)) + llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order"); + if (!result.isLastComputeOfCpu.count(expectedLast)) + llvm::report_fatal_error("merge scheduling: missing last-compute marker"); + } + + for (const ComputeGraphEdge &edge : graph.edges) { + const ComputeInstance source = graph.nodes[edge.source].instance; + const ComputeInstance target = graph.nodes[edge.target].instance; + const size_t sourceCpu = result.computeToCpuMap.lookup(source); + const size_t targetCpu = result.computeToCpuMap.lookup(target); + const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source); + const size_t targetSlot = result.computeToCpuSlotMap.lookup(target); + const Time sourceStart = static_cast