Compare commits
5 Commits
df703f0be9
...
0f13269040
| Author | SHA1 | Date | |
|---|---|---|---|
| 0f13269040 | |||
| dafc1d15b7 | |||
| 3fa140be25 | |||
| 9fa850c140 | |||
| 25ade1bd63 |
@@ -19,19 +19,19 @@ pub mod utility;
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CoreInstructionsBuilder {
|
pub struct CoreInstructionsBuilder {
|
||||||
core_instructions: Vec<CoreInstruction>,
|
core_instructions: Vec<CoreInstructions>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CoreInstructionsBuilder {
|
impl CoreInstructionsBuilder {
|
||||||
pub fn new(size: usize) -> Self {
|
pub fn new(size: usize) -> Self {
|
||||||
let mut core_instructions = Vec::with_capacity(size);
|
let mut core_instructions = Vec::with_capacity(size);
|
||||||
for _ in 0..=size {
|
for _ in 0..=size {
|
||||||
core_instructions.push(CoreInstruction::empty());
|
core_instructions.push(CoreInstructions::empty());
|
||||||
}
|
}
|
||||||
Self { core_instructions }
|
Self { core_instructions }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(self) -> Vec<CoreInstruction> {
|
pub fn build(self) -> Vec<CoreInstructions> {
|
||||||
self.core_instructions
|
self.core_instructions
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,12 +43,12 @@ impl CoreInstructionsBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CoreInstruction {
|
pub struct CoreInstructions {
|
||||||
instructions: Instructions,
|
instructions: Instructions,
|
||||||
program_counter: usize,
|
program_counter: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CoreInstruction {
|
impl CoreInstructions {
|
||||||
fn new(instructions: Instructions, program_counter: usize) -> Self {
|
fn new(instructions: Instructions, program_counter: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
instructions,
|
instructions,
|
||||||
@@ -64,9 +64,9 @@ impl CoreInstruction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Instructions> for CoreInstruction {
|
impl From<Instructions> for CoreInstructions {
|
||||||
fn from(value: Instructions) -> Self {
|
fn from(value: Instructions) -> Self {
|
||||||
CoreInstruction {
|
CoreInstructions {
|
||||||
instructions: value,
|
instructions: value,
|
||||||
program_counter: 0,
|
program_counter: 0,
|
||||||
}
|
}
|
||||||
@@ -76,27 +76,27 @@ impl From<Instructions> for CoreInstruction {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Executable<'a> {
|
pub struct Executable<'a> {
|
||||||
cpu: CPU<'a>,
|
cpu: CPU<'a>,
|
||||||
core_instructions: Vec<CoreInstruction>,
|
core_instructions: Vec<CoreInstructions>,
|
||||||
send_recv: SendRecv,
|
send_recv: SendRecv,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn print_status(core_instructions: &[CoreInstruction]) {
|
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||||
for (i, core_instruction) in core_instructions.iter().enumerate() {
|
let mut tot_instructions = 0;
|
||||||
println!(
|
let mut progress = 0;
|
||||||
"Core {} : {}% ({}/{}) ",
|
for core_instruction in core_instructions.iter() {
|
||||||
i,
|
tot_instructions += core_instruction.instructions.len();
|
||||||
core_instruction.program_counter as f32 / core_instruction.instructions.len() as f32
|
progress += core_instruction.program_counter;
|
||||||
* 100.0,
|
|
||||||
core_instruction.program_counter,
|
|
||||||
core_instruction.instructions.len()
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
println!(
|
||||||
println!();
|
"Progress: {}% ({}/{}) ",
|
||||||
|
progress as f32 / tot_instructions as f32 * 100.0,
|
||||||
|
progress,
|
||||||
|
tot_instructions
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Executable<'a> {
|
impl<'a> Executable<'a> {
|
||||||
pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstruction>) -> Executable<'a> {
|
pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstructions>) -> Executable<'a> {
|
||||||
let num_core = cpu.num_core();
|
let num_core = cpu.num_core();
|
||||||
let send_recv = SendRecv::new(num_core);
|
let send_recv = SendRecv::new(num_core);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -117,21 +117,21 @@ impl<'a> Executable<'a> {
|
|||||||
{
|
{
|
||||||
let Self {
|
let Self {
|
||||||
cpu,
|
cpu,
|
||||||
core_instructions,
|
core_instructions: cores_instructions,
|
||||||
send_recv,
|
send_recv,
|
||||||
} = self;
|
} = self;
|
||||||
let mut cpu_progressed = 0;
|
let mut cpu_progressed = 0;
|
||||||
let max_core = cpu.num_core();
|
let max_core = cpu.num_core();
|
||||||
let mut index_unit = 0;
|
let mut cpu_index = 0;
|
||||||
let now = SystemTime::now();
|
let mut now = SystemTime::now();
|
||||||
|
|
||||||
while (cpu_progressed > -2) {
|
while (cpu_progressed > -2) {
|
||||||
let mut core_result = InstructionStatus::Completed;
|
let mut core_result = InstructionStatus::Completed;
|
||||||
while core_result.is_completed()
|
while core_result.is_completed()
|
||||||
&& let Some(core_instruction) = core_instructions.get_mut(index_unit)
|
&& let Some(core_instruction) = cores_instructions.get_mut(cpu_index)
|
||||||
{
|
{
|
||||||
core_result = InstructionStatus::NotExecuted;
|
core_result = InstructionStatus::NotExecuted;
|
||||||
let CoreInstruction {
|
let CoreInstructions {
|
||||||
instructions,
|
instructions,
|
||||||
program_counter,
|
program_counter,
|
||||||
} = core_instruction;
|
} = core_instruction;
|
||||||
@@ -144,21 +144,31 @@ impl<'a> Executable<'a> {
|
|||||||
cpu_progressed = 0;
|
cpu_progressed = 0;
|
||||||
*program_counter += 1;
|
*program_counter += 1;
|
||||||
}
|
}
|
||||||
|
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
|
||||||
|
print_status(&cores_instructions);
|
||||||
|
now = SystemTime::now();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if handle_send_recv(cpu, core_instructions, send_recv, core_result) {
|
handle_wait_sync(cpu, cores_instructions, core_result);
|
||||||
cpu_progressed = 0;
|
match handle_send_recv(cpu, cores_instructions, send_recv, core_result) {
|
||||||
}
|
(true, other_cpu_index) => {
|
||||||
handle_wait_sync(cpu, core_instructions, core_result);
|
cpu_progressed = 0;
|
||||||
index_unit = if index_unit + 1 >= max_core {
|
cpu_index = other_cpu_index;
|
||||||
cpu_progressed -= 1;
|
}
|
||||||
0
|
(false, 0) => {
|
||||||
} else {
|
cpu_index = if cpu_index + 1 >= cores_instructions.len() {
|
||||||
index_unit + 1
|
cpu_progressed -= 1;
|
||||||
};
|
0
|
||||||
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
|
} else {
|
||||||
print_status(&core_instructions);
|
cpu_index + 1
|
||||||
|
};
|
||||||
|
}
|
||||||
|
(false, other_cpu_index) => {
|
||||||
|
cpu_index = other_cpu_index;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
print_status(cores_instructions);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cpu(&self) -> &CPU<'a> {
|
pub fn cpu(&self) -> &CPU<'a> {
|
||||||
@@ -182,7 +192,7 @@ impl<'a> Executable<'a> {
|
|||||||
|
|
||||||
fn handle_wait_sync<'a, 'b, 'c>(
|
fn handle_wait_sync<'a, 'b, 'c>(
|
||||||
cpu: &'b mut CPU<'a>,
|
cpu: &'b mut CPU<'a>,
|
||||||
core_instructions: &'c mut [CoreInstruction],
|
core_instructions: &'c mut [CoreInstructions],
|
||||||
core_result: InstructionStatus,
|
core_result: InstructionStatus,
|
||||||
) where
|
) where
|
||||||
'a: 'b,
|
'a: 'b,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
CoreInstruction, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
|
CoreInstructions, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
|
||||||
utility::add_offset_rd,
|
utility::add_offset_rd,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -43,14 +43,14 @@ impl SendRecv {
|
|||||||
|
|
||||||
pub fn handle_send_recv<'a, 'b >(
|
pub fn handle_send_recv<'a, 'b >(
|
||||||
cpu: &'b mut CPU<'a>,
|
cpu: &'b mut CPU<'a>,
|
||||||
core_instructions: & mut [CoreInstruction],
|
core_instructions: & mut [CoreInstructions],
|
||||||
send_recv: & mut SendRecv,
|
send_recv: & mut SendRecv,
|
||||||
core_result: InstructionStatus,
|
core_result: InstructionStatus,
|
||||||
) -> bool
|
) -> (bool, usize)
|
||||||
where 'a : 'b
|
where 'a : 'b
|
||||||
{
|
{
|
||||||
let transfer_memory = |cpu: &'b mut CPU<'a>,
|
let transfer_memory = |cpu: &'b mut CPU<'a>,
|
||||||
core_instructions: & mut [CoreInstruction],
|
core_instructions: & mut [CoreInstructions],
|
||||||
sender: Option<SendRecvInfo>,
|
sender: Option<SendRecvInfo>,
|
||||||
receiver: Option<SendRecvInfo>| {
|
receiver: Option<SendRecvInfo>| {
|
||||||
if let Some(sender) = sender
|
if let Some(sender) = sender
|
||||||
@@ -119,7 +119,7 @@ where 'a : 'b
|
|||||||
send_recv.sending[sender] = None;
|
send_recv.sending[sender] = None;
|
||||||
send_recv.receiving[receiver] = None;
|
send_recv.receiving[receiver] = None;
|
||||||
}
|
}
|
||||||
transfered
|
(transfered, receiver)
|
||||||
}
|
}
|
||||||
InstructionStatus::Reciving(instruction_data) => {
|
InstructionStatus::Reciving(instruction_data) => {
|
||||||
let (core_idx, imm_core) = instruction_data.get_core_immcore();
|
let (core_idx, imm_core) = instruction_data.get_core_immcore();
|
||||||
@@ -148,8 +148,8 @@ where 'a : 'b
|
|||||||
send_recv.sending[sender] = None;
|
send_recv.sending[sender] = None;
|
||||||
send_recv.receiving[receiver] = None;
|
send_recv.receiving[receiver] = None;
|
||||||
}
|
}
|
||||||
transfered
|
(transfered, sender)
|
||||||
}
|
}
|
||||||
_ => false,
|
_ => (false, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@@ -33,6 +34,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
|||||||
return &memEntries.emplace_back(memEntry, value).first;
|
return &memEntries.emplace_back(memEntry, value).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimMemory::allocateGatheredMemory() {
|
||||||
|
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
||||||
|
for (auto& [memEntry, value] : memEntries)
|
||||||
|
allocateMemoryForValue(value, memEntry);
|
||||||
|
}
|
||||||
|
|
||||||
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||||
memEntry.address = firstAvailableAddress;
|
memEntry.address = firstAvailableAddress;
|
||||||
firstAvailableAddress += memEntry.size;
|
firstAvailableAddress += memEntry.size;
|
||||||
@@ -44,35 +51,37 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||||
// More than one SSA value per single global constant:
|
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||||
// Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times
|
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||||
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
|
||||||
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!hasWeightAlways(getGlobalOp)) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
auto iter = globalConstants.find(globalMemrefOp);
|
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||||
if (iter == globalConstants.end())
|
if (inserted)
|
||||||
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
gatherMemEntry(getGlobalOp.getResult());
|
||||||
else {
|
else
|
||||||
MemEntry memEntry = *iter->second;
|
globalAliases.push_back({getGlobalOp.getResult(), iter->second});
|
||||||
globalMemEntriesMap[getGlobalOp] = memEntry;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (mlir::Value arg : funcOp.getArguments())
|
for (mlir::Value arg : funcOp.getArguments())
|
||||||
gatherMemEntry(arg);
|
gatherMemEntry(arg);
|
||||||
|
|
||||||
allocateCore(funcOp);
|
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||||
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
gatherMemEntry(allocOp.getResult());
|
||||||
|
});
|
||||||
|
|
||||||
|
allocateGatheredMemory();
|
||||||
|
|
||||||
|
for (auto [alias, original] : globalAliases)
|
||||||
|
globalMemEntriesMap[alias] = getMemEntry(original);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::allocateCore(Operation* op) {
|
void PimMemory::allocateCore(Operation* op) {
|
||||||
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
||||||
|
|
||||||
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
allocateGatheredMemory();
|
||||||
for (auto& [memEntry, value] : memEntries)
|
|
||||||
allocateMemoryForValue(value, memEntry);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||||
@@ -465,6 +474,19 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
return std::to_string(size) + " Bytes";
|
return std::to_string(size) + " Bytes";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||||
|
SmallVector<unsigned, 8> indices;
|
||||||
|
auto addIndex = [&](unsigned weightIndex) {
|
||||||
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
|
indices.push_back(weightIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||||
|
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
|
llvm::sort(indices);
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||||
static OnnxMlirCompilerErrorCodes
|
static OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
@@ -478,12 +500,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
|||||||
|
|
||||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (hasWeightAlways(getGlobalOp))
|
if (hasWeightAlways(getGlobalOp))
|
||||||
return;
|
return;
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
if (!globalOp)
|
if (!globalOp)
|
||||||
return;
|
return;
|
||||||
|
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||||
|
return;
|
||||||
auto initialValue = globalOp.getInitialValue();
|
auto initialValue = globalOp.getInitialValue();
|
||||||
if (!initialValue)
|
if (!initialValue)
|
||||||
return;
|
return;
|
||||||
@@ -658,7 +683,12 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp) {
|
if (!getGlobalOp) {
|
||||||
@@ -855,7 +885,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
|
|
||||||
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
||||||
json::Array xbarsPerGroup;
|
json::Array xbarsPerGroup;
|
||||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
xbarsPerGroup.push_back(index);
|
xbarsPerGroup.push_back(index);
|
||||||
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
||||||
auto& fileName = mapWeightToFile[weight];
|
auto& fileName = mapWeightToFile[weight];
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class PimMemory {
|
|||||||
size_t firstAvailableAddress = 0;
|
size_t firstAvailableAddress = 0;
|
||||||
|
|
||||||
MemEntry* gatherMemEntry(mlir::Value value);
|
MemEntry* gatherMemEntry(mlir::Value value);
|
||||||
|
void allocateGatheredMemory();
|
||||||
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ llvm::cl::opt<long> coresCount("core-count",
|
|||||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||||
llvm::cl::init(-1));
|
llvm::cl::init(-1));
|
||||||
|
|
||||||
|
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||||
|
"dcp-critical-window-size",
|
||||||
|
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||||
|
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||||
|
llvm::cl::init(1024));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
ignoreConcatError("ignore-concat-error",
|
ignoreConcatError("ignore-concat-error",
|
||||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
|||||||
extern llvm::cl::opt<size_t> crossbarSize;
|
extern llvm::cl::opt<size_t> crossbarSize;
|
||||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||||
extern llvm::cl::opt<long> coresCount;
|
extern llvm::cl::opt<long> coresCount;
|
||||||
|
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||||
|
|
||||||
// This option, by default set to false, will ignore an error when resolving a
|
// This option, by default set to false, will ignore an error when resolving a
|
||||||
// specific tiles of the operands of a concat. This specific case is when the
|
// specific tiles of the operands of a concat. This specific case is when the
|
||||||
|
|||||||
@@ -93,15 +93,22 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](PimCoreOp coreOp) {
|
||||||
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
auto annotateWeight = [&](unsigned weightIndex) {
|
||||||
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
if (weightIndex >= coreOp.getWeights().size())
|
||||||
if (isAlwaysWeight) {
|
return;
|
||||||
|
Value weight = coreOp.getWeights()[weightIndex];
|
||||||
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return;
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
markWeightAlways(getGlobalOp);
|
markWeightAlways(getGlobalOp);
|
||||||
markWeightAlways(globalMemrefOp);
|
markWeightAlways(globalMemrefOp);
|
||||||
}
|
};
|
||||||
|
|
||||||
|
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
|
||||||
|
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,18 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <map>
|
||||||
|
#include <numeric>
|
||||||
|
#include <optional>
|
||||||
|
#include <set>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -17,6 +25,361 @@ namespace spatial {
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct VirtualNode {
|
||||||
|
llvm::SmallVector<size_t, 4> originalComputeIndices;
|
||||||
|
Weight weight = 0;
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VirtualGraph {
|
||||||
|
std::vector<VirtualNode> nodes;
|
||||||
|
std::vector<IndexedEdge> edges;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TimingInfo {
|
||||||
|
std::vector<Time> aest;
|
||||||
|
std::vector<Time> alst;
|
||||||
|
std::vector<size_t> topologicalOrder;
|
||||||
|
bool valid = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WindowScheduleResult {
|
||||||
|
std::vector<std::vector<size_t>> mergeGroups;
|
||||||
|
bool usedAllAvailableCpus = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
|
for (auto [start, end, weight] : edges) {
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
if (startIndex == endIndex)
|
||||||
|
continue;
|
||||||
|
auto key = std::make_pair(startIndex, endIndex);
|
||||||
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
|
auto it = edgeWeights.find(key);
|
||||||
|
if (it == edgeWeights.end())
|
||||||
|
edgeWeights.insert({key, edgeWeight});
|
||||||
|
else
|
||||||
|
it->second = std::max(it->second, edgeWeight);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> aggregatedEdges;
|
||||||
|
aggregatedEdges.reserve(edgeWeights.size());
|
||||||
|
for (auto [key, weight] : edgeWeights)
|
||||||
|
aggregatedEdges.push_back(
|
||||||
|
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||||
|
return aggregatedEdges;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
||||||
|
llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
VirtualGraph graph;
|
||||||
|
graph.nodes.reserve(spatWeightedComputes.size());
|
||||||
|
for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
||||||
|
VirtualNode node;
|
||||||
|
node.originalComputeIndices.push_back(index);
|
||||||
|
node.weight = getSpatComputeWeight(spatWeightedCompute);
|
||||||
|
node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute);
|
||||||
|
graph.nodes.push_back(std::move(node));
|
||||||
|
}
|
||||||
|
graph.edges = aggregateEdges(edges);
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TimingInfo computeTiming(const VirtualGraph& graph) {
|
||||||
|
TimingInfo timing;
|
||||||
|
size_t nodeCount = graph.nodes.size();
|
||||||
|
timing.aest.assign(nodeCount, 0);
|
||||||
|
timing.alst.assign(nodeCount, 0);
|
||||||
|
timing.topologicalOrder.reserve(nodeCount);
|
||||||
|
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||||
|
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||||
|
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
|
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||||
|
children[startIndex].push_back({endIndex, edgeWeight});
|
||||||
|
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||||
|
incomingEdgeCount[endIndex]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> readyNodes;
|
||||||
|
readyNodes.reserve(nodeCount);
|
||||||
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
|
if (incomingEdgeCount[i] == 0)
|
||||||
|
readyNodes.push_back(i);
|
||||||
|
|
||||||
|
size_t readyIndex = 0;
|
||||||
|
while (readyIndex != readyNodes.size()) {
|
||||||
|
size_t current = readyNodes[readyIndex++];
|
||||||
|
timing.topologicalOrder.push_back(current);
|
||||||
|
for (auto [child, weight] : children[current]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||||
|
incomingEdgeCount[child]--;
|
||||||
|
if (incomingEdgeCount[child] == 0)
|
||||||
|
readyNodes.push_back(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timing.topologicalOrder.size() != nodeCount)
|
||||||
|
return timing;
|
||||||
|
|
||||||
|
Time dcpl = 0;
|
||||||
|
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||||
|
Time maxParentAest = 0;
|
||||||
|
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||||
|
maxParentAest =
|
||||||
|
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||||
|
}
|
||||||
|
timing.aest[nodeIndex] = maxParentAest;
|
||||||
|
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||||
|
Time minAlst = std::numeric_limits<Time>::max();
|
||||||
|
if (children[nodeIndex].empty())
|
||||||
|
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||||
|
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||||
|
minAlst =
|
||||||
|
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||||
|
}
|
||||||
|
timing.alst[nodeIndex] = minAlst;
|
||||||
|
}
|
||||||
|
|
||||||
|
timing.valid = true;
|
||||||
|
return timing;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t windowSize) {
|
||||||
|
std::vector<size_t> selected(timing.aest.size());
|
||||||
|
std::iota(selected.begin(), selected.end(), 0);
|
||||||
|
std::stable_sort(selected.begin(), selected.end(), [&](size_t lhs, size_t rhs) {
|
||||||
|
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||||
|
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||||
|
if (lhsSlack != rhsSlack)
|
||||||
|
return lhsSlack < rhsSlack;
|
||||||
|
if (timing.aest[lhs] != timing.aest[rhs])
|
||||||
|
return timing.aest[lhs] < timing.aest[rhs];
|
||||||
|
return lhs < rhs;
|
||||||
|
});
|
||||||
|
selected.resize(std::min(windowSize, selected.size()));
|
||||||
|
return selected;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
|
||||||
|
std::vector<size_t> signature;
|
||||||
|
for (size_t nodeIndex : selectedNodes) {
|
||||||
|
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||||
|
signature.insert(signature.end(), node.originalComputeIndices.begin(), node.originalComputeIndices.end());
|
||||||
|
}
|
||||||
|
std::sort(signature.begin(), signature.end());
|
||||||
|
return signature;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
||||||
|
std::vector<IndexedEdge> windowEdges;
|
||||||
|
windowEdges.reserve(graph.edges.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||||
|
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||||
|
if (mappedStart == -1 || mappedEnd == -1)
|
||||||
|
continue;
|
||||||
|
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||||
|
}
|
||||||
|
return aggregateEdges(windowEdges);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs DCP scheduling on the sub-graph induced by `selectedNodes` and reports
// which of those nodes ended up on the same CPU.
//
// `selectedNodes` holds indices into `graph.nodes`; the position of a node in
// that list is its index inside the window graph. `context` is forwarded to
// the window scheduler via setContext().
//
// The result contains:
//   - usedAllAvailableCpus: true when the scheduler's cpuCount() reached (or
//     exceeded) its getMaxCpuCount();
//   - mergeGroups: one entry per CPU that received at least two window nodes,
//     listing those nodes as sorted indices into `graph.nodes`. CPUs with
//     fewer than two scheduled tasks contribute nothing.
WindowScheduleResult
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
  const size_t windowSize = selectedNodes.size();

  // -1 marks graph nodes that are not part of the window.
  std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
  std::vector<Weight> windowWeights;
  std::vector<CrossbarUsage> windowCrossbarUsage;
  windowWeights.reserve(windowSize);
  windowCrossbarUsage.reserve(windowSize);

  for (size_t slot = 0; slot < windowSize; ++slot) {
    const size_t graphNodeIndex = selectedNodes[slot];
    nodeToWindowIndex[graphNodeIndex] = static_cast<int64_t>(slot);
    const VirtualNode& node = graph.nodes[graphNodeIndex];
    windowWeights.push_back(node.weight);
    windowCrossbarUsage.push_back(node.crossbarUsage);
  }

  GraphDCP windowGraph(windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowCrossbarUsage);
  // A non-positive coresCount leaves the scheduler's own CPU limit untouched.
  if (coresCount.getValue() > 0)
    windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
  windowGraph.setContext(context);
  windowGraph.runDcp();

  WindowScheduleResult result;
  result.usedAllAvailableCpus = windowGraph.cpuCount() >= windowGraph.getMaxCpuCount();

  for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
    auto tasksOnCpu = windowGraph.getScheduledTasks(cpu);
    // A single task (or none) on a CPU gives nothing to merge.
    const bool hasSomethingToMerge = tasksOnCpu.size() >= 2;
    if (!hasSomethingToMerge)
      continue;

    // Translate window-local task indices back to indices into graph.nodes.
    std::vector<size_t> cpuGroup;
    cpuGroup.reserve(tasksOnCpu.size());
    for (const auto& task : tasksOnCpu)
      cpuGroup.push_back(selectedNodes[task.nodeIndex]);
    std::sort(cpuGroup.begin(), cpuGroup.end());

    result.mergeGroups.push_back(std::move(cpuGroup));
  }

  return result;
}
|
||||||
|
|
||||||
|
bool coarsenGraph(const VirtualGraph& graph,
|
||||||
|
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
|
VirtualGraph& coarsenedGraph) {
|
||||||
|
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||||
|
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
|
||||||
|
if (mergeGroup.size() < 2)
|
||||||
|
continue;
|
||||||
|
for (size_t nodeIndex : mergeGroup) {
|
||||||
|
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||||
|
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::optional<size_t>> mergeGroupToNewNode(mergeGroups.size());
|
||||||
|
std::vector<size_t> oldToNewNode(graph.nodes.size(), 0);
|
||||||
|
bool mergedAny = false;
|
||||||
|
coarsenedGraph.nodes.clear();
|
||||||
|
coarsenedGraph.edges.clear();
|
||||||
|
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||||
|
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
|
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||||
|
if (mergeGroupIndex == -1) {
|
||||||
|
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||||
|
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||||
|
if (newNodeIndex.has_value()) {
|
||||||
|
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualNode mergedNode;
|
||||||
|
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
|
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||||
|
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
||||||
|
memberNode.originalComputeIndices.end());
|
||||||
|
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||||
|
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||||
|
}
|
||||||
|
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
|
||||||
|
|
||||||
|
mergedAny = true;
|
||||||
|
newNodeIndex = coarsenedGraph.nodes.size();
|
||||||
|
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||||
|
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||||
|
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!mergedAny)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> remappedEdges;
|
||||||
|
remappedEdges.reserve(graph.edges.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||||
|
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||||
|
if (newStart == newEnd)
|
||||||
|
continue;
|
||||||
|
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||||
|
}
|
||||||
|
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||||
|
|
||||||
|
return computeTiming(coarsenedGraph).valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
||||||
|
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
|
VirtualGraph& coarsenedGraph) {
|
||||||
|
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
std::vector<size_t> orderedGroupIndices(mergeGroups.size());
|
||||||
|
std::iota(orderedGroupIndices.begin(), orderedGroupIndices.end(), 0);
|
||||||
|
std::stable_sort(orderedGroupIndices.begin(), orderedGroupIndices.end(), [&](size_t lhs, size_t rhs) {
|
||||||
|
return mergeGroups[lhs].size() > mergeGroups[rhs].size();
|
||||||
|
});
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> acceptedMergeGroups;
|
||||||
|
acceptedMergeGroups.reserve(mergeGroups.size());
|
||||||
|
for (size_t groupIndex : orderedGroupIndices) {
|
||||||
|
std::vector<std::vector<size_t>> candidateMergeGroups = acceptedMergeGroups;
|
||||||
|
candidateMergeGroups.push_back(mergeGroups[groupIndex]);
|
||||||
|
|
||||||
|
VirtualGraph candidateGraph;
|
||||||
|
if (!coarsenGraph(graph, candidateMergeGroups, candidateGraph))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
acceptedMergeGroups = std::move(candidateMergeGroups);
|
||||||
|
coarsenedGraph = std::move(candidateGraph);
|
||||||
|
}
|
||||||
|
return !acceptedMergeGroups.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
VirtualGraph graph;
|
||||||
|
graph.nodes.resize(computeCount);
|
||||||
|
graph.edges = aggregateEdges(edges);
|
||||||
|
TimingInfo timing = computeTiming(graph);
|
||||||
|
if (timing.valid)
|
||||||
|
return timing.topologicalOrder;
|
||||||
|
|
||||||
|
std::vector<size_t> fallbackOrder(computeCount);
|
||||||
|
std::iota(fallbackOrder.begin(), fallbackOrder.end(), 0);
|
||||||
|
return fallbackOrder;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the final virtual graph into a DCPAnalysisResult.
//
// Each virtual node index is used as the CPU id for every original compute
// listed in that node's originalComputeIndices. Computes are emitted into
// dominanceOrderCompute following computeOriginalTopologicalOrder() over
// `originalEdges`; while walking that order, computeToCpuMap is filled and
// cpuToLastComputeMap is overwritten so that it ends up holding, per CPU, the
// compute that appears last in the order. Those last computes are then
// collected into isLastComputeOfCpu.
//
// Original computes not referenced by any virtual node keep the default
// mapping to CPU 0.
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
                                              llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
                                              llvm::ArrayRef<IndexedEdge> originalEdges) {
  DCPAnalysisResult result;

  // original compute index -> index of the virtual node that contains it
  std::vector<size_t> virtualNodeOf(spatWeightedComputes.size(), 0);
  for (size_t virtualIndex = 0; virtualIndex < graph.nodes.size(); ++virtualIndex) {
    for (size_t originalIndex : graph.nodes[virtualIndex].originalComputeIndices)
      virtualNodeOf[originalIndex] = virtualIndex;
  }

  const std::vector<size_t> emissionOrder =
      computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges);
  result.dominanceOrderCompute.reserve(emissionOrder.size());
  for (size_t originalIndex : emissionOrder) {
    SpatWeightedCompute compute = spatWeightedComputes[originalIndex];
    const size_t assignedCpu = virtualNodeOf[originalIndex];

    result.dominanceOrderCompute.push_back(compute);
    result.computeToCpuMap[compute] = assignedCpu;
    // Later computes in the order overwrite earlier ones for the same CPU.
    result.cpuToLastComputeMap[assignedCpu] = compute;
  }

  for (const auto& [cpuId, finalCompute] : result.cpuToLastComputeMap)
    result.isLastComputeOfCpu.insert(finalCompute);

  return result;
}
|
||||||
|
|
||||||
|
// Runs DCP directly on the full set of computes, with no windowing or
// coarsening, and returns the scheduler's own result.
//
// The CPU limit is only set when the coresCount option is positive; `context`
// is forwarded to the scheduler via setContext().
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
                               llvm::ArrayRef<IndexedEdge> edges,
                               MLIRContext* context) {
  GraphDCP scheduler(spatWeightedComputes, edges);

  const auto requestedCores = coresCount.getValue();
  if (requestedCores > 0)
    scheduler.setMaxCpuCount(static_cast<int>(requestedCores));

  scheduler.setContext(context);
  scheduler.runDcp();
  return scheduler.getResult();
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
@@ -31,8 +394,8 @@ SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult DCPAnalysis::run() {
|
DCPAnalysisResult DCPAnalysis::run() {
|
||||||
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
||||||
llvm::SmallVector<IndexedEdge, 10> edges;
|
SmallVector<IndexedEdge, 10> edges;
|
||||||
for (auto& region : entryOp->getRegions())
|
for (auto& region : entryOp->getRegions())
|
||||||
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
|
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
|
||||||
spatWeightedComputes.push_back(spatWeightedCompute);
|
spatWeightedComputes.push_back(spatWeightedCompute);
|
||||||
@@ -53,10 +416,37 @@ DCPAnalysisResult DCPAnalysis::run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GraphDCP graphDCP(spatWeightedComputes, edges);
|
|
||||||
graphDCP.setContext(entryOp->getContext());
|
if (dcpCriticalWindowSize.getValue() == 0)
|
||||||
graphDCP.runDcp();
|
return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext());
|
||||||
return graphDCP.getResult();
|
|
||||||
|
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges);
|
||||||
|
std::set<std::vector<size_t>> seenCriticalWindows;
|
||||||
|
while (virtualGraph.nodes.size() > 1) {
|
||||||
|
TimingInfo timing = computeTiming(virtualGraph);
|
||||||
|
if (!timing.valid)
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto selectedNodes = selectCriticalWindow(timing, dcpCriticalWindowSize.getValue());
|
||||||
|
if (selectedNodes.size() < 2)
|
||||||
|
break;
|
||||||
|
|
||||||
|
if (!seenCriticalWindows.insert(getOriginalSignature(virtualGraph, selectedNodes)).second)
|
||||||
|
break;
|
||||||
|
|
||||||
|
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||||
|
if (windowSchedule.mergeGroups.empty())
|
||||||
|
break;
|
||||||
|
|
||||||
|
VirtualGraph coarsenedGraph;
|
||||||
|
if (!coarsenGraphWithFallback(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph))
|
||||||
|
break;
|
||||||
|
virtualGraph = std::move(coarsenedGraph);
|
||||||
|
if (windowSchedule.usedAllAvailableCpus)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
// consumer land on different CPUs.
|
// consumer land on different CPUs.
|
||||||
//
|
//
|
||||||
// Output: an assignment of every task to a CPU and an order within that CPU,
|
// Output: an assignment of every task to a CPU and an order within that CPU,
|
||||||
// aiming to minimise the overall critical-path length (DCPL).
|
// aiming to minimize the overall critical-path length (DCPL).
|
||||||
//
|
//
|
||||||
// Every task keeps two timing estimates:
|
// Every task keeps two timing estimates:
|
||||||
// AEST - earliest start time, driven by parent completions + transfers.
|
// AEST - earliest start time, driven by parent completions + transfers.
|
||||||
@@ -16,9 +16,9 @@
|
|||||||
// Main loop (runDcp):
|
// Main loop (runDcp):
|
||||||
// 1. Build a topological order and seed AEST/ALST from the unscheduled DAG.
|
// 1. Build a topological order and seed AEST/ALST from the unscheduled DAG.
|
||||||
// 2. While there are ready tasks (all dependency parents scheduled):
|
// 2. While there are ready tasks (all dependency parents scheduled):
|
||||||
// a. Pick the candidate with tightest slack (earliest AEST breaks ties).
|
// a. Pick the candidate with the tightest slack (earliest AEST breaks ties).
|
||||||
// b. selectProcessor() tries every candidate CPU and picks the one that
|
// b. selectProcessor() tries every candidate CPU and picks the one that
|
||||||
// minimises a composite cost (own slot + smallest unscheduled child).
|
// minimizes a composite cost (own slot + the smallest unscheduled child).
|
||||||
// c. Commit the placement and refresh AEST/ALST.
|
// c. Commit the placement and refresh AEST/ALST.
|
||||||
// d. Release any child whose dependency parents are now all scheduled.
|
// d. Release any child whose dependency parents are now all scheduled.
|
||||||
//
|
//
|
||||||
@@ -43,7 +43,6 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
|
|||||||
Reference in New Issue
Block a user