19 Commits

Author SHA1 Message Date
ilgeco 5637c861b4 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-19 15:00:11 +02:00
ilgeco 94157a8404 Very big timeout 2026-05-19 14:53:34 +02:00
ilgeco 68a3521978 Perft topological fix 2026-05-19 14:52:54 +02:00
NiccoloN e263e05f56 remove dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 18:32:40 +02:00
ilgeco 34c29fdec4 Materialize modification
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:22:13 +02:00
ilgeco aa088e2ba5 Verify fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:20:40 +02:00
NiccoloN 2836e759ab remove useless file
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:51:03 +02:00
NiccoloN 8071ebab0b faster refactored merge pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:50:19 +02:00
NiccoloN f1602c0550 add peft scheduling
Validate Operations / validate-operations (push) Has been cancelled
better deadlock report by pim simulator
2026-05-18 12:09:27 +02:00
NiccoloN de0a2f4561 remove useless guard in gemm lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:22:13 +02:00
NiccoloN 1c4a5bde76 compact softmax op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:14:59 +02:00
NiccoloN 78242e2887 compact resize op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 17:36:12 +02:00
NiccoloN fe244d5aa1 new ops tests for matmul, grouped conv, concat and reshape
Validate Operations / validate-operations (push) Has been cancelled
related fixes
2026-05-14 15:54:06 +02:00
NiccoloN d09e76c8f9 fix matmul rewriting/lowering
Validate Operations / validate-operations (push) Has been cancelled
fix reshape lowering
add support for grouped-convolution lowering
quieter verifier with capped error messages
2026-05-14 14:09:30 +02:00
NiccoloN c5e608fa5b replace greedy pattern rewrites with partial conversions
Validate Operations / validate-operations (push) Has been cancelled
better failure messages
2026-05-14 11:48:16 +02:00
ilgeco 43f3ccdd21 new yolo nodes with 100% more statics
Validate Operations / validate-operations (push) Has been cancelled
2026-05-14 10:47:31 +02:00
NiccoloN 8d95c604a6 automatic code formatting
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 21:51:19 +02:00
NiccoloN 55eda487dc use seed in validate.py for deterministic tests 2026-05-13 21:49:36 +02:00
NiccoloN 061139aefb fix wrong send/receive reordering in post dcp merge instructions compaction 2026-05-13 21:48:49 +02:00
130 changed files with 4782 additions and 2639 deletions
+5 -2
View File
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
run only the codegen tail. run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and - `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count. per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum). - `--core-count=<N>` — number of cores. Required for PIM compilation.
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
merge-compute-nodes pass (default: `peft`).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy). - `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering. - `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`. - `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
``` ```
validate.py \ validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include --onnx-include-dir ../onnx-mlir/include \
--core-count 1000
``` ```
End-to-end network validation (example: first 4 layers of YOLOv11n): End-to-end network validation (example: first 4 layers of YOLOv11n):
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
.lock() .lock()
.unwrap() .unwrap()
.init(executor.cpu().num_core(), args.output.clone()); .init(executor.cpu().num_core(), args.output.clone());
executor.execute(); executor.execute()?;
dump_memory(executor, &args)?; dump_memory(executor, &args)?;
Ok(()) Ok(())
} }
@@ -1,5 +1,6 @@
#![allow(unused)] #![allow(unused)]
use anyhow::{Result, bail};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
time::{Duration, SystemTime}, time::{Duration, SystemTime},
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
send_recv: SendRecv, send_recv: SendRecv,
} }
struct DeadlockInfo {
cycle: String,
states: String,
}
fn print_status(core_instructions: &[CoreInstructions]) { fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0; let mut tot_instructions = 0;
let mut progress = 0; let mut progress = 0;
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
} }
} }
pub fn execute<'b>(&'b mut self) pub fn execute<'b>(&'b mut self) -> Result<()>
where where
'a: 'b, 'a: 'b,
{ {
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
} }
if (now.elapsed().unwrap() > Duration::from_secs(5)) { if (now.elapsed().unwrap() > Duration::from_secs(5)) {
print_status(cores_instructions); print_status(cores_instructions);
check_cycle(cpu, cores_instructions, send_recv); if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
now = SystemTime::now(); now = SystemTime::now();
} }
} }
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
} }
print_status(cores_instructions); print_status(cores_instructions);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
if cores_instructions
.iter()
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
{
bail!("Execution stalled with unfinished instructions");
}
#[cfg(feature = "profile_time")] #[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report(); TRACER.lock().unwrap().report();
Ok(())
} }
pub fn cpu(&self) -> &CPU<'a> { pub fn cpu(&self) -> &CPU<'a> {
@@ -201,11 +228,11 @@ impl<'a> Executable<'a> {
} }
} }
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) { fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
enum CoreState { enum CoreState {
SendingTo(i32), SendingTo(i32, i32),
ReceivingFrom(i32), ReceivingFrom(i32, i32),
Working, Working,
Halted, Halted,
} }
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
let (this_core, target_core) = data.get_core_immcore(); let (this_core, target_core) = data.get_core_immcore();
if isa_recv(functor_address) { if isa_recv(functor_address) {
states.insert(this_core, CoreState::ReceivingFrom(target_core)); states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
} else if isa_send(functor_address) { } else if isa_send(functor_address) {
states.insert(this_core, CoreState::SendingTo(target_core)); states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
} else { } else {
states.insert(this_core, CoreState::Working); states.insert(this_core, CoreState::Working);
} }
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
for (&core_id, state) in states.iter() { for (&core_id, state) in states.iter() {
match state { match state {
CoreState::SendingTo(target_core) => { CoreState::SendingTo(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted); let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::ReceivingFrom(core_id) { if target_state != &CoreState::ReceivingFrom(core_id, *size) {
wait_for.insert(core_id, *target_core); wait_for.insert(core_id, *target_core);
} }
} }
CoreState::ReceivingFrom(target_core) => { CoreState::ReceivingFrom(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted); let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::SendingTo(core_id) { if target_state != &CoreState::SendingTo(core_id, *size) {
wait_for.insert(core_id, *target_core); wait_for.insert(core_id, *target_core);
} }
} }
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(" -> "); .join(" -> ");
let cycle = cycle
.iter()
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for); let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source)
}
CoreState::Working => format!("core {} working", core),
CoreState::Halted => format!("core {} halted", core),
})
})
.collect::<Vec<_>>()
.join(", ");
println!("Fatal: Deadlock cycle detected: {}", cycle_msg); return Some(DeadlockInfo {
// bail!("Deadlock detected: {}", cycle_msg); cycle: cycle_msg,
break; // Stop tracing states: states_msg,
});
} }
// Hit a known branch that didn't result in a cycle // Hit a known branch that didn't result in a cycle
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
current_core = waiting_for; current_core = waiting_for;
} }
} }
None
} }
fn handle_wait_sync<'a, 'b, 'c>( fn handle_wait_sync<'a, 'b, 'c>(
+8
View File
@@ -110,6 +110,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs)); return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
} }
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) { if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
+1
View File
@@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
mlir::arith::SubIOp, mlir::arith::SubIOp,
mlir::arith::MulIOp, mlir::arith::MulIOp,
mlir::arith::DivUIOp, mlir::arith::DivUIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp, mlir::arith::RemUIOp,
mlir::arith::IndexCastOp, mlir::arith::IndexCastOp,
mlir::memref::AllocOp, mlir::memref::AllocOp,
+1 -2
View File
@@ -1,7 +1,6 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
+24
View File
@@ -7,10 +7,34 @@
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <system_error> #include <system_error>
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
numFailures++;
if (numFailures <= maxReportedFailures)
emit(op);
}
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
private:
int64_t maxReportedFailures;
int64_t numFailures = 0;
};
/// Emits a consistent diagnostic for target paths that require static shapes. /// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription); mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
+1 -2
View File
@@ -1,8 +1,7 @@
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" #include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
namespace onnx_mlir { namespace onnx_mlir {
+1 -2
View File
@@ -1,10 +1,9 @@
#pragma once #pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <fstream> #include <fstream>
#include <limits> #include <limits>
#include <string> #include <string>
+36 -43
View File
@@ -70,9 +70,7 @@ inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
os.write(bytes.data(), bytes.size()); os.write(bytes.data(), bytes.size());
} }
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
writeUint32LE(os, static_cast<uint32_t>(value));
}
inline void writeHeader(llvm::raw_ostream& os) { inline void writeHeader(llvm::raw_ostream& os) {
os.write(kMagic, sizeof(kMagic)); os.write(kMagic, sizeof(kMagic));
@@ -186,39 +184,39 @@ inline Opcode opcodeFromString(llvm::StringRef opName) {
inline llvm::StringRef opcodeToString(Opcode opcode) { inline llvm::StringRef opcodeToString(Opcode opcode) {
switch (opcode) { switch (opcode) {
case Opcode::nop: return "nop"; case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi"; case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld"; case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd"; case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub"; case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul"; case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi"; case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli"; case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw"; case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul"; case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd"; case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub"; case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul"; case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul"; case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax"; case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll"; case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra"; case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg"; case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu"; case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh"; case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm"; case Opcode::vsigm: return "vsigm";
case Opcode::vsoftmax: return "vsoftmax"; case Opcode::vsoftmax: return "vsoftmax";
case Opcode::vmv: return "vmv"; case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu"; case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl"; case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld"; case Opcode::ld: return "ld";
case Opcode::st: return "st"; case Opcode::st: return "st";
case Opcode::lldi: return "lldi"; case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv"; case Opcode::lmv: return "lmv";
case Opcode::send: return "send"; case Opcode::send: return "send";
case Opcode::recv: return "recv"; case Opcode::recv: return "recv";
case Opcode::wait: return "wait"; case Opcode::wait: return "wait";
case Opcode::sync: return "sync"; case Opcode::sync: return "sync";
} }
llvm_unreachable("Unsupported PIM binary opcode"); llvm_unreachable("Unsupported PIM binary opcode");
} }
@@ -235,9 +233,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
case Opcode::sldi: case Opcode::sldi:
case Opcode::saddi: case Opcode::saddi:
case Opcode::smuli: case Opcode::smuli:
case Opcode::lldi: case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
record.r2OrImm = getOptionalInt(instruction, "imm");
break;
case Opcode::mvmul: case Opcode::mvmul:
record.r2OrImm = getOptionalInt(instruction, "mbiw"); record.r2OrImm = getOptionalInt(instruction, "mbiw");
record.generic1 = getOptionalInt(instruction, "relu"); record.generic1 = getOptionalInt(instruction, "relu");
@@ -252,9 +248,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
record.r2OrImm = getOptionalInt(instruction, "core"); record.r2OrImm = getOptionalInt(instruction, "core");
record.generic3 = getOptionalInt(instruction, "size"); record.generic3 = getOptionalInt(instruction, "size");
break; break;
default: default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
record.r2OrImm = getOptionalInt(instruction, "rs2");
break;
} }
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) { if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
@@ -371,8 +365,7 @@ inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
break; break;
case Opcode::wait: case Opcode::wait:
case Opcode::sync: case Opcode::sync:
case Opcode::nop: case Opcode::nop: break;
break;
} }
return instruction; return instruction;
+1 -1
View File
@@ -367,7 +367,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
instruction.generic1 = 0; instruction.generic1 = 0;
instruction.generic2 = 0; instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size); instruction.generic3 = static_cast<int32_t>(size);
(void)sizeFieldName; (void) sizeFieldName;
emitInstruction(instruction); emitInstruction(instruction);
} }
+22 -3
View File
@@ -1,5 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "PimCompilerOptions" #define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir { namespace onnx_mlir {
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen), llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
"pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen", pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"), llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256)); crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count", llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."), llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1)); llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize( llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size", "dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. " llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."), "Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
llvm::cl::init(4000)); llvm::cl::init(4000));
llvm::cl::opt<bool> llvm::cl::opt<bool>
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"), llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false)); llvm::cl::init(false));
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
void verifyExplicitPimCoreCount() {
if (!hasExplicitPimCoreCount())
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
if (coresCount.getValue() <= 0)
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
}
} // namespace onnx_mlir } // namespace onnx_mlir
+9
View File
@@ -20,8 +20,14 @@ typedef enum {
EmitPimCodegen = 3 EmitPimCodegen = 3
} PimEmissionTargetType; } PimEmissionTargetType;
typedef enum {
MergeSchedulerPeft = 0,
MergeSchedulerDcp = 1,
} PimMergeSchedulerType;
extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget; extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<bool> pimOnlyCodegen; extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl; extern llvm::cl::opt<bool> useExperimentalConvImpl;
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount; extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize; extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
// This option, by default set to false, will ignore an error when resolving a // This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the // specific tiles of the operands of a concat. This specific case is when the
// wanted tile is generated by two separate operands of the concat. If this is // wanted tile is generated by two separate operands of the concat. If this is
+1
View File
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
PassManager& pm, PassManager& pm,
EmissionTargetType& emissionTarget, EmissionTargetType& emissionTarget,
std::string outputNameNoExt) { std::string outputNameNoExt) {
verifyExplicitPimCoreCount();
if (pimOnlyCodegen) { if (pimOnlyCodegen) {
// Skip all the lowering passes and directly generate code for PIM. // Skip all the lowering passes and directly generate code for PIM.
+52 -11
View File
@@ -33,7 +33,7 @@ struct DenseWeightView {
}; };
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews; SmallVector<Operation*> viewOps;
mlir::Value current = weight; mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp; memref::GetGlobalOp getGlobalOp;
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview)) if (!hasAllStaticSubviewParts(subview))
return failure(); return failure();
subviews.push_back(subview); viewOps.push_back(subview);
current = subview.getSource(); current = subview.getSource();
continue; continue;
} }
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
current = cast.getSource(); current = cast.getSource();
continue; continue;
} }
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(collapse);
current = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(expand);
current = expand.getSrc();
continue;
}
return failure(); return failure();
} }
@@ -70,16 +88,39 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape); view.strides = computeRowMajorStrides(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) { for (Operation* viewOp : llvm::reverse(viewOps)) {
SmallVector<int64_t> nextStrides; if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
nextStrides.reserve(subview.getStaticStrides().size()); SmallVector<int64_t> nextStrides;
for (auto [offset, stride, sourceStride] : nextStrides.reserve(subview.getStaticStrides().size());
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { for (auto [offset, stride, sourceStride] :
view.offset += offset * sourceStride; llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
nextStrides.push_back(stride * sourceStride); view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
// Collapse/expand are accepted only as contiguous static reshapes of a
// dense global view, so a row-major stride recomputation preserves layout.
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
} }
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
} }
return view; return view;
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
return tiles; return tiles;
} }
tensor::SplatOp Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType()); auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType(); Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length}; int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType); Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); auto buildBroadcast = [&](Value input) -> Value {
SmallVector<Value> index(oldType.getRank(), zero); auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult(); SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
return tensor::SplatOp::create(rewriter, loc, type, elementValue); if (isHostFoldableValue(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile,
mlir::ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc); mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length, int64_t length,
mlir::ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc); mlir::Location loc);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,8 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; }); return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
} }
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(),
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
}
static bool isStaticTensorResult(Operation* op) { static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) { return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type); auto shapedType = dyn_cast<ShapedType>(type);
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
}); });
} }
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
SmallVector<int64_t> originalIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
return DenseElementsAttr::get(resultType, values);
}
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
tensor::ExtractSliceOp extractSliceOp) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<Attribute> resultValues;
resultValues.reserve(resultType.getNumElements());
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
int64_t remaining = linearIndex;
int64_t sourceLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
}
resultValues.push_back(sourceValues[sourceLinearIndex]);
}
return DenseElementsAttr::get(resultType, resultValues);
}
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
// Rebuild dense attributes through view-only host-foldable chains so later
// lowering stages can still recognize grouped/sliced constants.
if (auto denseAttr = getDirectDenseConstantAttr(value))
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm;
perm.reserve(transposeOp.getPermAttr().size());
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
perm.push_back(attr.getInt());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
}
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) { static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second) if (!op || !visited.insert(op).second)
return false; return false;
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op)) if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true; return true;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
if (!isStaticTensorResult(op)) if (!isStaticTensorResult(op))
return false; return false;
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource()); return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op)) if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput()); return isHostFoldableValue(extractRowsOp.getInput());
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
return isHostFoldableOpImpl(op, visited); return isHostFoldableOpImpl(op, visited);
} }
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,5 +1,6 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op); bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -11,7 +12,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) { LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false; pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getFunctionBody().front()) { for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op)) if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
if (isHostFoldableOp(&op)) if (isHostFoldableOp(&op))
continue; continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute"); diagnostics.report(&op, [](Operation* illegalOp) {
hasFailure = true; illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
"spat.compute");
});
} }
return success(!hasFailure); diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
return success(!diagnostics.hasFailure());
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -5,17 +5,15 @@
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "Common/Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
@@ -87,17 +85,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult); returnOp.setOperand(index, computeResult);
} }
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
RewritePatternSet prePatterns(ctx); RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx); populatePrePatterns(prePatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns)))) if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing"); moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp); auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) { if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
signalPassFailure();
return;
}
RewritePatternSet matmulPatterns(ctx);
populateMatMulRewritePatterns(matmulPatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
bool hasUnloweredMatMul = false;
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
hasUnloweredMatMul = true;
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
});
if (hasUnloweredMatMul) {
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -130,31 +179,28 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet conversionPatterns(ctx); RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx); populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
signalPassFailure(); signalPassFailure();
return; return;
} }
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
RewritePatternSet earlyPostPatterns(ctx); RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx); populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) { if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
signalPassFailure(); signalPassFailure();
return; return;
} }
if (coresCount != -1) {
int computeOpsCount = 0;
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
<< coresCount << ")";
signalPassFailure();
return;
}
}
PassManager cleanupPM(ctx); PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass()); cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp))) if (failed(cleanupPM.run(moduleOp)))
@@ -162,14 +208,29 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
RewritePatternSet postPatterns(ctx); RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx); populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) { if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
signalPassFailure(); signalPassFailure();
return; return;
} }
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) { if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -11,6 +11,7 @@
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override; ConversionPatternRewriter& rewriter) const override;
}; };
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); } static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
return collectComputeOp.getResult(0); return collectComputeOp.getResult(0);
} }
} // namespace static Value lowerSingleConvGroup(Value x,
Value w,
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, Value b,
ONNXConvOpAdaptor convOpAdaptor, RankedTensorType xType,
ConversionPatternRewriter& rewriter) const { RankedTensorType wType,
Location loc = convOp.getLoc(); RankedTensorType outType,
Value x = convOpAdaptor.getX(); int64_t padHeightBegin,
Value w = convOpAdaptor.getW(); int64_t padHeightEnd,
Value b = convOpAdaptor.getB(); int64_t padWidthBegin,
int64_t padWidthEnd,
auto xType = cast<RankedTensorType>(x.getType()); int64_t strideHeight,
auto wType = cast<RankedTensorType>(w.getType()); int64_t strideWidth,
auto outType = cast<RankedTensorType>(convOp.getY().getType()); int64_t dilationHeight,
int64_t dilationWidth,
if (!xType.hasStaticShape()) { ConversionPatternRewriter& rewriter,
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); Location loc) {
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0); const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1); const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2); const int64_t xHeight = xType.getDimSize(2);
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t outHeight = outType.getDimSize(2); const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3); const int64_t outWidth = outType.getDimSize(3);
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position // A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue()); const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut); const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim); const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getDenseConstantAttr(w); auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage: // Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
DenseElementsAttr biasDenseAttr; DenseElementsAttr biasDenseAttr;
if (hasB) { if (hasB) {
gemmBias = b; gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b); biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc); biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
} }
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr); const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
rewriter.getBoolAttr(false)) rewriter.getBoolAttr(false))
.getY(); .getY();
rewriter.replaceOp(convOp, return createCollectedConvOutput(ValueRange {gemmRows},
createCollectedConvOutput(ValueRange {gemmRows}, outType,
convOp.getType(), gemmOutType,
gemmOutType, nhwcType,
nhwcType, outType,
outType, numPatches,
numPatches, numChannelsOut,
numChannelsOut, effectiveMaxParallelPixels,
effectiveMaxParallelPixels, rewriter,
rewriter, loc);
loc)); }
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() < 1) {
convOp.emitOpError("requires group >= 1 for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
const int64_t group = convOp.getGroup();
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
if (numChannelsIn % group != 0) {
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
if (numChannelsOut % group != 0) {
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
const int64_t numChannelsInPerGroup = numChannelsIn / group;
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
if (wType.getDimSize(1) != numChannelsInPerGroup) {
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
return failure();
}
if (wType.getDimSize(0) != numChannelsOut) {
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
<< numChannelsOut << " for Spatial lowering";
return failure();
}
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
if (group == 1) {
rewriter.replaceOp(convOp,
lowerSingleConvGroup(x,
w,
b,
xType,
wType,
outType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
return success();
}
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
SmallVector<Value> bSlices;
if (hasB) {
auto biasType = cast<RankedTensorType>(b.getType());
int64_t biasAxis = -1;
if (biasType.getRank() == 1)
biasAxis = 0;
else if (biasType.getRank() == 2)
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
else {
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
<< biasType.getRank();
return failure();
}
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
}
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
return failure();
}
SmallVector<Value> groupResults;
groupResults.reserve(group);
auto groupOutType =
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
for (int64_t groupId = 0; groupId < group; groupId++) {
Value groupX = xSlices[groupId];
Value groupW = wSlices[groupId];
Value groupB = hasB ? bSlices[groupId] : noBias;
groupResults.push_back(lowerSingleConvGroup(groupX,
groupW,
groupB,
cast<RankedTensorType>(groupX.getType()),
cast<RankedTensorType>(groupW.getType()),
groupOutType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
});
result = concatCompute.getResult(0);
}
rewriter.replaceOp(convOp, result);
return success(); return success();
} }
@@ -502,9 +502,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
} }
(void) bType; (void) bType;
if (!isHostFoldableValue(b))
return failure();
Value sharedBias; Value sharedBias;
if (hasC) { if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
@@ -2,8 +2,12 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <functional>
#include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
} }
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
if (rhsBatchShape.empty())
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
return failure();
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
}
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return expandCompute.getResult(0);
}
static Value extractBatchMatrix(Value value, static Value extractBatchMatrix(Value value,
int64_t batchIndex, int64_t batchIndex,
int64_t batchSize, int64_t batchSize,
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType()); auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape(); auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) { if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0})); perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
} }
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1})); return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
} }
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) { static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape()) || !outType.hasStaticShape())
return failure(); return failure();
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3) if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|| (outType.getRank() != 2 && outType.getRank() != 3))
return failure(); return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape()) if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape())) || !haveStaticPositiveShape(outType.getShape()))
return failure(); return failure();
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1; SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1; SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
const int64_t batch = std::max(lhsBatch, rhsBatch); auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(batchShape))
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
return failure(); return failure();
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0); const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1); const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0); const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1); const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
if (k != rhsK) if (k != rhsK)
return failure(); return failure();
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure(); return failure();
} }
else { else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n) SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|| outType.getDimSize(outType.getRank() - 1) != n)
return failure(); return failure();
} }
Location loc = matmulOp.getLoc(); Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB()); bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
Value lhs = matmulOp.getA(); Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = matmulOp.getB(); Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
int64_t lhsBatchForGemm = lhsBatch; int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch; int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m; int64_t gemmM = m;
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
} }
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc); Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, result);
return success(); return success();
} }
@@ -1,9 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
return permutedShape; return permutedShape;
} }
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { static Value buildLoopSoftmaxSlice(Value input,
Value accumulator,
RankedTensorType inputType,
ArrayRef<Value> outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = inputType.getRank();
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
sliceShape.push_back(inputType.getDimSize(rank - 1));
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
offsets.reserve(rank);
sizes.reserve(rank);
for (Value outerIndex : outerIndices) {
offsets.push_back(outerIndex);
sizes.push_back(rewriter.getIndexAttr(1));
}
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
}
static Value buildLoopSoftmaxNest(Value input,
Value accumulator,
RankedTensorType inputType,
int64_t axis,
SmallVectorImpl<Value>& outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == inputType.getRank() - 1)
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
Value loopIndex = loop.getInductionVar();
Value loopAccumulator = loop.getRegionIterArgs().front();
outerIndices.push_back(loopIndex);
Value updatedAccumulator =
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
outerIndices.pop_back();
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
rewriter.setInsertionPointAfter(loop);
return loop.getResult(0);
}
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); if (inputType.getRank() == 1) {
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
spatial::SpatYieldOp::create(rewriter, loc, softmax);
return;
}
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
SmallVector<Value> outerIndices;
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, result);
}); });
return computeOp.getResult(0); return computeOp.getResult(0);
} }
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (axis == inputType.getRank())
return createSoftmaxCompute(input, rewriter, loc);
if (axis == softmaxAxis)
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> rebuiltSlices;
rebuiltSlices.reserve(slices.size());
for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return concatValues(rebuiltSlices, axis, rewriter, loc);
}
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> { struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value input = adaptor.getInput(); Value input = adaptor.getInput();
Value result; Value result;
if (axis == inputType.getRank() - 1) { if (axis == inputType.getRank() - 1) {
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
} }
else { else {
SmallVector<int64_t> permutation; SmallVector<int64_t> permutation;
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
}); });
Value transposedInput = preTransposeCompute.getResult(0); Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax( Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
auto postTransposeCompute = auto postTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) { createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
Value transposed = ONNXTransposeOp::create( Value transposed = ONNXTransposeOp::create(
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
} }
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> { struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
}); });
if (sourceType.getNumElements() != resultType.getNumElements())
return failure();
return replaceWithReshape([&](Value data) -> Value {
Value reshaped = data;
if (sourceType.getRank() != 1) {
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
reshaped = tensor::CollapseShapeOp::create(
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
}
if (resultType.getRank() == 1)
return reshaped;
return tensor::ExpandShapeOp::create(
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
.getResult();
});
return failure(); return failure();
} }
}; };
@@ -1,10 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -15,42 +15,88 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static Value static Value buildNearestAsymmetricIndex(
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) { Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
SmallVector<OpFoldResult> sizes; Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1)); Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
sizes.reserve(inputType.getRank()); Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
for (int64_t dim : inputType.getShape()) return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(1);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
} }
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) { static Value buildNearestResizeLoop(Value input,
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1); RankedTensorType inputType,
} RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = resultType.getElementType();
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
static Value buildNearestResize(Value input, SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
ArrayRef<int64_t> inputShape, SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
ArrayRef<int64_t> outputShape,
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == static_cast<int64_t>(outputShape.size()))
return input;
SmallVector<Value> slices; Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
slices.reserve(outputShape[axis]); Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) { Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]); Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc); Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
}
return createSpatConcat(rewriter, loc, axis, slices); Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(batchLoop.getBody());
Value outputN = batchLoop.getInductionVar();
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
rewriter.setInsertionPointToStart(channelLoop.getBody());
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
Value updatedOutput =
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
scf::YieldOp::create(rewriter, loc, updatedOutput);
rewriter.setInsertionPointAfter(widthLoop);
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
rewriter.setInsertionPointAfter(heightLoop);
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
rewriter.setInsertionPointAfter(channelLoop);
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
rewriter.setInsertionPointAfter(batchLoop);
return batchLoop.getResult(0);
} }
struct Resize : OpConversionPattern<ONNXResizeOp> { struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType()); auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType()); auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
if (inputType.getRank() != 4 || resultType.getRank() != 4)
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor") || resizeOp.getNearestMode() != "floor")
return failure(); return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return failure(); return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
auto computeOp = auto computeOp =
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
Value result = Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
}); });
rewriter.replaceOp(resizeOp, computeOp.getResults()); rewriter.replaceOp(resizeOp, computeOp.getResults());
@@ -31,6 +31,21 @@ static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp()); return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
} }
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= block.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
continue;
return true;
}
return false;
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily. // Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> { struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern; using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
@@ -262,4 +277,10 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
}); });
} }
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -3,8 +3,16 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
patterns.add<convAddToConvWithBiasLeft>(ctx); patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx); patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx); patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx); patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
auto entryFunc = getPimEntryFunc(module); auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) { if (failed(entryFunc)) {
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -138,12 +138,13 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
} }
void SpatialToPimPass::runOnOperation() { void SpatialToPimPass::runOnOperation() {
coreId = 1; coreId = 0;
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp); auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) { if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -169,26 +170,22 @@ void SpatialToPimPass::runOnOperation() {
spatial::SpatChannelSendTensorBatchOp, spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>(); spatial::SpatExtractRowsOp>();
{ RewritePatternSet initialPatterns(ctx);
RewritePatternSet patterns(ctx); populateWithGenerated(initialPatterns);
populateWithGenerated(patterns); if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure();
signalPassFailure(); return;
return;
}
} }
{ RewritePatternSet globalTensorPatterns(ctx);
RewritePatternSet patterns(ctx); populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
populateGlobalTensorMaterializationPatterns(patterns); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
walkAndApplyPatterns(moduleOp, std::move(patterns));
}
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors); addReturnOutputBuffers(returnOp, rewriter, outputTensors);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -197,6 +194,7 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) { for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp); markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -205,17 +203,16 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) { for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp); markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) { if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure(); signalPassFailure();
return; return;
} }
} }
{ RewritePatternSet initialTensorPackingPatterns(ctx);
RewritePatternSet patterns(ctx); populateTensorPackingPatterns(initialTensorPackingPatterns);
populateTensorPackingPatterns(patterns); walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
walkAndApplyPatterns(funcOp, std::move(patterns)); eraseUnusedTensorPackingOps(funcOp, rewriter);
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
SmallVector<spatial::SpatChannelReceiveOp> receiveOps; SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>()) for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
@@ -229,27 +226,27 @@ void SpatialToPimPass::runOnOperation() {
} }
} }
{ RewritePatternSet coreBodyPatterns(ctx);
RewritePatternSet coreBodyPatterns(ctx); populateWithGenerated(coreBodyPatterns);
populateWithGenerated(coreBodyPatterns); FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps; SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) { for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure(); coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
return; signalPassFailure();
} return;
} }
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps; SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) { for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure(); coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
return; signalPassFailure();
} return;
} }
} }
@@ -259,44 +256,43 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) { if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure(); signalPassFailure();
return; return;
} }
{ RewritePatternSet finalTensorPackingPatterns(ctx);
RewritePatternSet patterns(ctx); populateTensorPackingPatterns(finalTensorPackingPatterns);
populateTensorPackingPatterns(patterns); walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
walkAndApplyPatterns(funcOp, std::move(patterns)); eraseUnusedTensorPackingOps(funcOp, rewriter);
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
{ ConversionTarget communicationTarget(*ctx);
ConversionTarget communicationTarget(*ctx); communicationTarget.addLegalDialect<PimDialect,
communicationTarget.addLegalDialect<PimDialect, tensor::TensorDialect,
tensor::TensorDialect, arith::ArithDialect,
arith::ArithDialect, bufferization::BufferizationDialect,
bufferization::BufferizationDialect, func::FuncDialect,
func::FuncDialect, memref::MemRefDialect,
memref::MemRefDialect, scf::SCFDialect,
scf::SCFDialect, BuiltinDialect>();
BuiltinDialect>(); communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addLegalOp<ModuleOp>(); communicationTarget.addIllegalOp<spatial::SpatConcatOp,
communicationTarget.addIllegalOp<spatial::SpatConcatOp, spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorOp, spatial::SpatChannelSendOp,
spatial::SpatChannelSendOp, spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorOp, spatial::SpatExtractRowsOp>();
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx); RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns); populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
signalPassFailure(); funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
return; signalPassFailure();
} return;
} }
if (failed(verifySpatialToPimBoundary(moduleOp))) { if (failed(verifySpatialToPimBoundary(moduleOp))) {
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -1,5 +1,4 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -75,16 +74,14 @@ struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConc
return failure(); return failure();
auto outputType = cast<ShapedType>(concatOp.getOutput().getType()); auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
auto newConcat = pim::PimConcatOp::create(rewriter, auto newConcat = pim::PimConcatOp::create(
concatOp.getLoc(), rewriter,
concatOp.getOutput().getType(), concatOp.getLoc(),
concatOp.getAxisAttr(), concatOp.getOutput().getType(),
ValueRange(packedInputs), concatOp.getAxisAttr(),
tensor::EmptyOp::create(rewriter, ValueRange(packedInputs),
concatOp.getLoc(), tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
outputType.getShape(), .getResult());
outputType.getElementType())
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput()); rewriter.replaceOp(concatOp, newConcat.getOutput());
return success(); return success();
} }
@@ -1,7 +1,7 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
return WalkResult::skip(); return WalkResult::skip();
}); });
if (hasFailed) { if (hasFailed) {
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -1,15 +1,16 @@
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include <limits> #include <limits>
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -29,9 +30,8 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8); return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
} }
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp, static FailureOr<uint64_t>
Block& body, getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp); uint64_t endInstruction = opOrder.lookup(allocOp);
SmallPtrSet<Operation*, 16> visited; SmallPtrSet<Operation*, 16> visited;
SmallVector<Value> pendingValues; SmallVector<Value> pendingValues;
@@ -45,9 +45,15 @@ static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
if (!visited.insert(user).second) if (!visited.insert(user).second)
continue; continue;
if (isSupportedAliasOp(user)) { if (isSupportedAliasOp(user))
for (Value result : user->getResults()) for (Value result : user->getResults())
pendingValues.push_back(result); pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
if (initArg == value)
pendingValues.push_back(forOp.getResult(index));
}
} }
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) { if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
@@ -2,7 +2,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@@ -45,9 +45,7 @@ struct CoalescingReportEntry {
CoalescingReportRow row; CoalescingReportRow row;
}; };
static std::string formatMemory(uint64_t bytes) { static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
return formatReportMemory(bytes);
}
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName); auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
@@ -58,9 +56,10 @@ static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) { static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
llvm::SmallVector<ReportField, 4> fields = { llvm::SmallVector<ReportField, 4> fields = {
{"Number of candidates", std::to_string(row.numCandidates)}, {"Number of candidates", std::to_string(row.numCandidates)},
{"Skipped allocations", std::to_string(row.numSkipped)}, {"Skipped allocations", std::to_string(row.numSkipped) },
{"Removed allocations", std::to_string(row.numRemoved)}, {"Removed allocations", std::to_string(row.numRemoved) },
{"Saved memory", formatMemory(row.savedBytes)}}; {"Saved memory", formatMemory(row.savedBytes) }
};
printReportFlatFields(os, fields); printReportFlatFields(os, fields);
} }
@@ -87,10 +86,12 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
totalRow.savedBytes += entryTotal.savedBytes; totalRow.savedBytes += entryTotal.savedBytes;
} }
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)}, llvm::SmallVector<ReportField, 4> totalFields = {
{"Skipped allocations", std::to_string(totalRow.numSkipped)}, {"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Removed allocations", std::to_string(totalRow.numRemoved)}, {"Skipped allocations", std::to_string(totalRow.numSkipped) },
{"Saved memory", formatMemory(totalRow.savedBytes)}}; {"Removed allocations", std::to_string(totalRow.numRemoved) },
{"Saved memory", formatMemory(totalRow.savedBytes) }
};
printReportTotalsBlock(os, totalFields); printReportTotalsBlock(os, totalFields);
if (!entries.empty()) if (!entries.empty())
os << "\n"; os << "\n";
@@ -127,15 +128,17 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) { if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
llvm::SmallVector<ReportField, 4> perCoreFields = { llvm::SmallVector<ReportField, 4> perCoreFields = {
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)}, {"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)}, {"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)}, {"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}}; {"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
};
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]); CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
llvm::SmallVector<ReportField, 4> totalFields = { llvm::SmallVector<ReportField, 4> totalFields = {
{"Number of candidates", std::to_string(totalRow.numCandidates)}, {"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped)}, {"Skipped allocations", std::to_string(totalRow.numSkipped) },
{"Removed allocations", std::to_string(totalRow.numRemoved)}, {"Removed allocations", std::to_string(totalRow.numRemoved) },
{"Saved memory", formatMemory(totalRow.savedBytes)}}; {"Saved memory", formatMemory(totalRow.savedBytes) }
};
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields); printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
} }
else { else {
@@ -196,8 +199,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
} // namespace } // namespace
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
return std::make_unique<StaticMemoryCoalescingPass>();
}
} // namespace onnx_mlir } // namespace onnx_mlir
+7
View File
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
SpatialOpsVerify.cpp SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
Transforms/MergeComputeNodes/RegularOpCompaction.cpp Transforms/MergeComputeNodes/RegularOpCompaction.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
+20 -3
View File
@@ -3,6 +3,7 @@
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
return success(); return success();
} }
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isa<SpatCompute, SpatComputeBatch>(op))
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute" );
}
return success();
}
LogicalResult SpatCompute::verify() { LogicalResult SpatCompute::verify() {
auto& block = getBody().front(); auto& block = getBody().front();
if (block.mightHaveTerminator()) { if (block.mightHaveTerminator()) {
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
for (auto arg : block.getArguments()) for (auto arg : block.getArguments())
if (arg.use_empty()) if (arg.use_empty())
return emitError("ComputeOp block argument is not used"); return emitError("ComputeOp block argument is not used");
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return success(); return success();
} }
@@ -465,8 +480,8 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("compute_batch coreIds attribute must be a dense i32 array"); return emitError("compute_batch coreIds attribute must be a dense i32 array");
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz)) if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
return emitError("compute_batch coreIds array length must match laneCount"); return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be positive"); return emitError("compute_batch coreIds values must be non-negative");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds; llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef()) for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second) if (!seenCoreIds.insert(coreId).second)
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("body block argument type must match input type"); return emitError("body block argument type must match input type");
} }
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
} }
@@ -1,802 +1,19 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <numeric>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"
#include "Graph.hpp" #include "../Scheduling/ComputeGraph.hpp"
#include "../Scheduling/DcpScheduler.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
using namespace mlir;
namespace {
using SpatCompute = onnx_mlir::spatial::SpatCompute;
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
struct VirtualGraph {
std::vector<VirtualNode> nodes;
std::vector<IndexedEdge> edges;
};
struct TimingInfo {
std::vector<Time> aest;
std::vector<Time> alst;
std::vector<size_t> topologicalOrder;
bool valid = false;
};
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
CPU cpuCount = 0;
size_t mergedNodeCount = 0;
size_t maxMergeGroupSize = 0;
};
size_t getSchedulingCpuBudget() {
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return std::numeric_limits<size_t>::max();
}
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
}
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
size_t chunkIndex = 0;
if (static_cast<size_t>(lane) < largeChunkSpan)
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
else
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
return getBatchChunkForIndex(batch, chunkIndex);
}
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
if (startIndex == endIndex)
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
});
return aggregatedEdges;
}
Weight getComputeBodyWeight(Region& body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : body)
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
CrossbarUsage crossbarUsage = 0;
for (auto& block : body)
for (auto& op : block)
if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
}
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op);
SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
inputs.push_back(batch.getInputs()[lane]);
return inputs;
}
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
Operation* op = value.getDefiningOp();
if (!op)
return std::nullopt;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource();
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto spatCompute = dyn_cast<SpatCompute>(op))
return ComputeInstance {spatCompute.getOperation(), 0, 1};
if (auto batch = dyn_cast<SpatComputeBatch>(op))
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
return std::nullopt;
}
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
SmallVector<ComputeInstance> instances;
auto isUsedAsWeightOnly = [](Operation* producerOp) {
if (producerOp->getNumResults() == 0)
return false;
for (Value result : producerOp->getResults()) {
if (result.use_empty())
return false;
for (Operation* user : result.getUsers()) {
if (auto compute = dyn_cast<SpatCompute>(user)) {
if (!llvm::is_contained(compute.getWeights(), result))
return false;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
if (!llvm::is_contained(batch.getWeights(), result))
return false;
continue;
}
return false;
}
}
return true;
};
for (Region& region : entryOp->getRegions()) {
for (Block& block : region) {
for (Operation& op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (isUsedAsWeightOnly(spatCompute.getOperation()))
continue;
instances.push_back({spatCompute.getOperation(), 0, 1});
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (isUsedAsWeightOnly(batch.getOperation()))
continue;
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
}
}
}
}
return instances;
}
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(computeInstances.size());
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
VirtualNode node;
node.originalComputeIndices.push_back(index);
node.weight = getComputeInstanceWeight(computeInstance);
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
graph.nodes.push_back(std::move(node));
}
graph.edges = aggregateEdges(edges);
return graph;
}
TimingInfo computeTiming(const VirtualGraph& graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
timing.alst.assign(nodeCount, 0);
timing.topologicalOrder.reserve(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
for (auto [start, end, weight] : graph.edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
Weight edgeWeight = static_cast<Weight>(weight);
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
children[startIndex].push_back({endIndex, edgeWeight});
parents[endIndex].push_back({startIndex, edgeWeight});
incomingEdgeCount[endIndex]++;
}
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode& node = graph.nodes[nodeIndex];
if (!node.originalComputeIndices.empty())
return node.originalComputeIndices.front();
return nodeIndex;
};
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
size_t lhsKey = getVirtualNodeOrderKey(lhs);
size_t rhsKey = getVirtualNodeOrderKey(rhs);
if (lhsKey != rhsKey)
return lhsKey > rhsKey;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t i = 0; i < nodeCount; ++i)
if (incomingEdgeCount[i] == 0)
readyNodes.push(i);
while (!readyNodes.empty()) {
size_t current = readyNodes.top();
readyNodes.pop();
timing.topologicalOrder.push_back(current);
for (auto [child, weight] : children[current]) {
(void) weight;
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push(child);
}
}
if (timing.topologicalOrder.size() != nodeCount)
return timing;
Time dcpl = 0;
for (size_t nodeIndex : timing.topologicalOrder) {
Time maxParentAest = 0;
for (auto [parent, transferCost] : parents[nodeIndex]) {
maxParentAest =
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
}
timing.aest[nodeIndex] = maxParentAest;
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
}
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
Time minAlst = std::numeric_limits<Time>::max();
if (children[nodeIndex].empty())
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
for (auto [child, transferCost] : children[nodeIndex]) {
minAlst =
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
}
timing.alst[nodeIndex] = minAlst;
}
timing.valid = true;
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto& neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
return lhsSlack < rhsSlack;
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
};
windowSize = std::min(windowSize, ranked.size());
if (windowSize == 0)
return {};
if (windowSize == ranked.size()) {
llvm::sort(ranked, isHigherPriority);
return ranked;
}
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
if (criticalPoolSize < ranked.size())
std::nth_element(
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
std::vector<char> inCriticalPool(ranked.size(), false);
for (size_t i = 0; i < criticalPoolSize; ++i)
inCriticalPool[ranked[i]] = true;
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
std::vector<size_t> selected;
std::vector<char> inWindow(ranked.size(), false);
selected.reserve(windowSize);
struct FrontierEntry {
size_t node;
};
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
selected.push_back(node);
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour] && eligible[neighbour])
frontier.push({neighbour});
};
addToWindow(seed, inCriticalPool);
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, inCriticalPool);
}
if (selected.size() < windowSize) {
std::vector<char> anyNode(ranked.size(), true);
for (size_t node : selected)
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour])
frontier.push({neighbour});
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, anyNode);
}
}
if (selected.size() < windowSize) {
llvm::sort(ranked, isHigherPriority);
for (size_t node : ranked) {
if (selected.size() == windowSize)
break;
if (!inWindow[node]) {
inWindow[node] = true;
selected.push_back(node);
}
}
}
llvm::sort(selected, isHigherPriority);
return selected;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
if (mappedStart == -1 || mappedEnd == -1)
continue;
windowEdges.push_back({mappedStart, mappedEnd, weight});
}
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
windowWeights.reserve(selectedNodes.size());
windowCrossbarUsage.reserve(selectedNodes.size());
windowNodeOrderKeys.reserve(selectedNodes.size());
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
windowWeights.push_back(graph.nodes[nodeIndex].weight);
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
}
GraphDCP windowGraph(
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
if (coresCount.getValue() > 0)
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
windowGraph.setContext(context);
windowGraph.runDcp();
WindowScheduleResult result;
result.cpuCount = windowGraph.cpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
result.mergedNodeCount += scheduledTasks.size();
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph& graph,
ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph,
std::vector<size_t>& oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
if (timing.valid)
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
topologicalRank[nodeIndex] = rank;
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto& mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
return topologicalRank[lhs] < topologicalRank[rhs];
return lhs < rhs;
});
}
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
std::vector<size_t> newNodeRank;
oldToNewNode.assign(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
newNodeRank.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
newNodeRank.push_back(topologicalRank[nodeIndex]);
continue;
}
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
}
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
memberNode.originalComputeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
if (!mergedAny)
return false;
std::vector<IndexedEdge> remappedEdges;
remappedEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
if (newNodeRank[newStart] >= newNodeRank[newEnd])
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return true;
}
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
if (nodeCount > static_cast<size_t>(maxCpuCount))
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
return windowSize;
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
DCPAnalysisResult result;
TimingInfo timing = computeTiming(graph);
std::vector<size_t> virtualNodeOrder;
if (timing.valid) {
virtualNodeOrder = std::move(timing.topologicalOrder);
}
else {
virtualNodeOrder.resize(graph.nodes.size());
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
}
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalComputeIndices)
originalComputeToCpu[originalIndex] = cpu;
}
result.dominanceOrderCompute.reserve(computeInstances.size());
llvm::DenseMap<size_t, size_t> nextCpuSlot;
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
size_t cpu = originalComputeToCpu[originalIndex];
result.dominanceOrderCompute.push_back(computeInstance);
result.computeToCpuMap[computeInstance] = cpu;
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
result.computeToAestMap[computeInstance] = originalIndex;
result.cpuToLastComputeMap[cpu] = computeInstance;
}
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
return result;
}
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
DCPAnalysisResult result;
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
if (scheduledTasks.empty())
continue;
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
ComputeInstance instance = computeInstances[task.nodeIndex];
result.computeToCpuMap[instance] = cpu;
result.computeToCpuSlotMap[instance] = slot;
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
}
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
}
return result;
}
DCPAnalysisResult
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
SmallVector<Weight> nodeWeights;
SmallVector<CrossbarUsage> nodeCrossbarUsage;
SmallVector<int64_t> nodeOrderKeys;
nodeWeights.reserve(computeInstances.size());
nodeCrossbarUsage.reserve(computeInstances.size());
nodeOrderKeys.reserve(computeInstances.size());
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
nodeWeights.push_back(getComputeInstanceWeight(instance));
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
nodeOrderKeys.push_back(static_cast<int64_t>(index));
}
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context);
graphDCP.runDcp();
return buildResultFromScheduledGraph(graphDCP, computeInstances);
}
} // namespace
SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op)
return {};
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
if (auto res = dyn_cast<SpatCompute>(op))
return res;
return {};
}
DCPAnalysisResult DCPAnalysis::run() { DCPAnalysisResult DCPAnalysis::run() {
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp); ComputeGraph graph = buildComputeGraph(entryOp);
SmallVector<IndexedEdge, 10> edges; DcpScheduleOptions options;
if (coresCount.getValue() > 0)
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex; options.processorCount = static_cast<size_t>(coresCount.getValue());
instanceToIndex.reserve(computeInstances.size()); options.criticalWindowSize = dcpCriticalWindowSize.getValue();
for (auto [index, instance] : llvm::enumerate(computeInstances)) options.allowFallbackForAutoCoreCount = true;
instanceToIndex[instance] = index; return runDcpScheduler(graph, options, entryOp->getContext());
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
for (Value input : getComputeInstanceInputs(computeInstance)) {
if (auto producerInstance = getOriginalComputeInstance(input)) {
auto producerIt = instanceToIndex.find(*producerInstance);
assert(producerIt != instanceToIndex.end());
auto indexStartEdge = producerIt->second;
edges.push_back({static_cast<int64_t>(indexStartEdge),
static_cast<int64_t>(indexEndEdge),
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
}
}
if (coresCount.getValue() > 0) {
size_t schedulingCpuBudget = getSchedulingCpuBudget();
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
});
if (needsExactScheduledBatches)
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
}
if (dcpCriticalWindowSize.getValue() == 0)
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
size_t iteration = 0;
bool debugCoarsening = isDcpCoarsenDebugEnabled();
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty()) {
if (debugCoarsening && oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount);
return false;
}
VirtualGraph coarsenedGraph;
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount,
windowSchedule.mergeGroups.size(),
windowSchedule.mergedNodeCount,
windowSchedule.maxMergeGroupSize,
coarsenedGraph.nodes.size(),
oldNodeCount - coarsenedGraph.nodes.size());
virtualGraph = std::move(coarsenedGraph);
return true;
};
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
}
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
}
SmallVector<size_t> selectedNodes;
auto criticalWindow =
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
selectedNodes.size());
break;
}
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
}
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
} }
} // namespace spatial } // namespace spatial
@@ -2,64 +2,27 @@
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h" #include "../Scheduling/MergeSchedule.hpp"
#include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <vector>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
// A scheduling identity that covers both spat.compute and scheduled shards of
// spat.compute_batch.
struct ComputeInstance {
mlir::Operation* op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance& other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
struct DCPAnalysisResult {
std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
};
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
using DCPAnalysisResult = MergeScheduleResult;
struct DCPAnalysis { struct DCPAnalysis {
private: private:
DCPAnalysisResult result; DCPAnalysisResult result;
mlir::Operation* entryOp; mlir::Operation *entryOp;
DCPAnalysisResult run(); DCPAnalysisResult run();
public: public:
DCPAnalysis(mlir::Operation* op) DCPAnalysis(mlir::Operation *op)
: entryOp(op) { : entryOp(op) {
result = run(); result = run();
} }
DCPAnalysisResult& getResult() { return result; } DCPAnalysisResult &getResult() { return result; }
}; };
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
namespace llvm { using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
template <>
struct DenseMapInfo<ComputeInstance> {
static ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
};
} // namespace llvm
@@ -0,0 +1,636 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <utility>
#include "MaterializeMergeSchedule.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
using SpatCompute = spatial::SpatCompute;
using ProducerValueRef = spatial::ProducerValueRef;
using spatial::getComputeInstanceInputs;
using spatial::getComputeInstanceOutputTypes;
using spatial::getComputeInstanceOutputValues;
using spatial::getComputeInstanceTemplateBlock;
using spatial::getComputeInstanceWeights;
using spatial::getProducerValueRef;
class MergeScheduleMaterializerImpl {
public:
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
: func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
schedule = &scheduleResult;
nextChannelId = &nextChannelIdRef;
collectScheduledTasks();
buildTaskIndex();
collectExternalInputsAndWeights();
planRemoteChannels();
planReceiveReordering();
createCpuComputeOps();
if (failed(cloneTaskBodies()))
return failure();
replaceExternalUses();
if (failed(eraseOldScheduledOps()))
return failure();
moveExternalUsersBeforeReturn();
return success();
}
private:
struct ScheduledTask {
ComputeInstance computeInstance;
size_t cpu = 0;
size_t orderWithinCpu = 0;
};
struct ChannelInfo {
int64_t channelId = -1;
int32_t sourceCoreId = -1;
int32_t targetCoreId = -1;
};
struct CpuProgram {
SpatCompute op;
DenseMap<Value, Value> externalInputMap;
DenseMap<Value, size_t> weightToIndex;
};
struct RemoteSendInfo {
ChannelInfo channelInfo;
ComputeInstance consumer;
size_t inputIndex = 0;
size_t consumerOrder = 0;
size_t sourceOrder = 0;
};
struct RemoteReceiveEntry {
ChannelInfo channelInfo;
ComputeInstance consumer;
size_t inputIndex = 0;
size_t sourceOrder = 0;
};
static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) {
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
| static_cast<uint32_t>(channelInfo.targetCoreId);
}
void collectExternalUsers(Operation* op) {
if (!externalUsersToMove.insert(op).second)
return;
for (Value result : op->getResults()) {
for (Operation* user : result.getUsers()) {
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
continue;
collectExternalUsers(user);
}
}
}
void collectScheduledTasks() {
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
oldComputeOps.insert(scheduledInstance.op);
scheduledTasks.push_back({scheduledInstance,
schedule->computeToCpuMap.lookup(scheduledInstance),
schedule->computeToCpuSlotMap.lookup(scheduledInstance)});
}
}
void buildTaskIndex() {
auto markCpuSeen = [&](size_t cpu) {
if (seenCpus.insert(cpu).second)
orderedCpus.push_back(cpu);
};
for (const ScheduledTask& task : scheduledTasks) {
taskByComputeInstance[task.computeInstance] = task;
tasksByCpu[task.cpu].push_back(task);
markCpuSeen(task.cpu);
}
llvm::sort(orderedCpus);
for (size_t cpu : orderedCpus)
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
return lhs.orderWithinCpu < rhs.orderWithinCpu;
});
}
void collectExternalInputsAndWeights() {
for (size_t cpu : orderedCpus) {
for (const ScheduledTask& task : tasksByCpu[cpu]) {
auto& thisCpuWeights = cpuWeights[cpu];
auto& thisSeenWeights = seenWeightsByCpu[cpu];
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
for (Value weight : taskWeights)
if (thisSeenWeights.insert(weight).second)
thisCpuWeights.push_back(weight);
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
remoteInputs.resize(taskInputs.size());
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
auto producerRef = getProducerValueRef(input);
if (producerRef) {
auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) {
if (producerIt->second.cpu != cpu) {
ChannelInfo info {
(*nextChannelId)++,
static_cast<int32_t>(producerIt->second.cpu),
static_cast<int32_t>(cpu),
};
remoteInputs[inputIndex] = info;
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
perResultChannels[producerRef->resultIndex].push_back(
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
}
continue;
}
}
if (seenExternalInputsByCpu[cpu].insert(input).second)
cpuExternalInputs[cpu].push_back(input);
}
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
bool hasExternalUser = false;
for (auto& use : output.getUses()) {
Operation* useOwner = use.getOwner();
if (oldComputeOps.contains(useOwner))
continue;
hasExternalUser = true;
if (!isa<func::ReturnOp>(useOwner))
collectExternalUsers(useOwner);
}
if (hasExternalUser)
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
}
}
}
}
void planRemoteChannels() {
for (size_t cpu : orderedCpus) {
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
for (const ScheduledTask& task : tasksByCpu[cpu]) {
auto sendsIt = remoteSendsByTask.find(task.computeInstance);
if (sendsIt == remoteSendsByTask.end())
continue;
for (auto& sendInfos : sendsIt->second) {
for (RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
if (!inserted) {
if (sendInfo.consumerOrder < it->second)
pairsNeedingReceiveReorder.insert(pairKey);
it->second = sendInfo.consumerOrder;
}
}
}
}
}
}
void planReceiveReordering() {
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
for (auto& taskSends : remoteSendsByTask) {
for (auto& sendInfos : taskSends.second) {
for (RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey))
reorderedSendsByPair[pairKey].push_back(&sendInfo);
}
}
}
for (auto& pairSends : reorderedSendsByPair) {
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
if (lhs->sourceOrder != rhs->sourceOrder)
return lhs->sourceOrder < rhs->sourceOrder;
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
});
for (RemoteSendInfo* sendInfo : pairSends.second) {
int64_t channelId = (*nextChannelId)++;
sendInfo->channelInfo.channelId = channelId;
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
}
}
for (const auto& taskSends : remoteSendsByTask) {
for (const auto& sendInfos : taskSends.second) {
for (const RemoteSendInfo& sendInfo : sendInfos) {
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
}
}
}
for (auto& taskSends : remoteSendsByTask) {
for (const auto& sendInfos : taskSends.second) {
for (const RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
if (!pairsNeedingReceiveReorder.contains(pairKey))
continue;
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId);
receiveQueuesByCpu[targetCpu][pairKey].push_back(
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
}
}
}
for (auto& cpuQueues : receiveQueuesByCpu) {
for (auto& pairQueue : cpuQueues.second) {
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
if (lhs.sourceOrder != rhs.sourceOrder)
return lhs.sourceOrder < rhs.sourceOrder;
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
});
}
}
}
void createCpuComputeOps() {
IRRewriter rewriter(func.getContext());
for (size_t cpu : orderedCpus) {
SmallVector<Value> operands;
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
llvm::append_range(operands, cpuWeights[cpu]);
llvm::append_range(operands, cpuExternalInputs[cpu]);
SmallVector<Type> resultTypes;
resultTypes.reserve(cpuExternalOutputs[cpu].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
}
rewriter.setInsertionPoint(returnOp);
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
for (Value input : cpuExternalInputs[cpu]) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(loc);
}
Block* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
CpuProgram program;
program.op = newCompute;
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
program.weightToIndex[weight] = weightIndex;
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
newCompute.getResult(resultIndex);
}
cpuPrograms[cpu] = std::move(program);
}
}
FailureOr<Value> receiveThroughInput(IRRewriter& rewriter,
size_t cpu,
DenseMap<uint64_t, size_t>& receiveQueueIndices,
DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
const ChannelInfo& requestedChannelInfo,
ComputeInstance requestedConsumer,
size_t requestedInputIndex) {
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
if (cpuQueuesIt == receiveQueuesByCpu.end())
return failure();
auto queueIt = cpuQueuesIt->second.find(pairKey);
if (queueIt == cpuQueuesIt->second.end())
return failure();
auto& queue = queueIt->second;
size_t& queueIndex = receiveQueueIndices[pairKey];
while (queueIndex < queue.size()) {
const RemoteReceiveEntry& entry = queue[queueIndex++];
auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
if (consumerTaskIt == taskByComputeInstance.end())
return failure();
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
if (consumerInputs.size() <= entry.inputIndex)
return failure();
Type inputType = consumerInputs[entry.inputIndex].getType();
auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
loc,
inputType,
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
auto& receivedInputs = preReceivedInputsByTask[entry.consumer];
if (receivedInputs.size() <= entry.inputIndex)
receivedInputs.resize(entry.inputIndex + 1);
receivedInputs[entry.inputIndex] = receive.getResult();
if (entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex)
return receive.getResult();
}
return failure();
}
LogicalResult cloneTaskBodies() {
for (size_t cpu : orderedCpus) {
CpuProgram& program = cpuPrograms[cpu];
IRRewriter rewriter(func.getContext());
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
DenseMap<uint64_t, size_t> receiveQueueIndices;
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
auto inputsIt = preReceivedInputsByTask.find(consumer);
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
return std::nullopt;
Value value = inputsIt->second[inputIndex];
if (!value)
return std::nullopt;
return value;
};
for (const ScheduledTask& task : tasksByCpu[cpu]) {
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
SmallVector<Value> resolvedInputs;
resolvedInputs.reserve(taskInputs.size());
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
auto producerRef = getProducerValueRef(input);
if (producerRef) {
auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) {
if (producerIt->second.cpu == cpu) {
auto producedIt = producedValuesByTask.find(producerRef->instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
task.computeInstance.op->emitOpError("missing local producer value during per-cpu merge materialization")
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
<< " producerLaneStart=" << producerRef->instance.laneStart
<< " producerLaneCount=" << producerRef->instance.laneCount;
return failure();
}
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
continue;
}
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) {
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
resolvedInputs.push_back(*preReceived);
continue;
}
FailureOr<Value> received = receiveThroughInput(rewriter,
cpu,
receiveQueueIndices,
preReceivedInputsByTask,
channelInfo,
task.computeInstance,
inputIndex);
if (failed(received)) {
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
return failure();
}
resolvedInputs.push_back(*received);
continue;
}
auto receive =
spatial::SpatChannelReceiveOp::create(rewriter,
loc,
input.getType(),
rewriter.getI64IntegerAttr(channelInfo.channelId),
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
rewriter.getI32IntegerAttr(channelInfo.targetCoreId));
resolvedInputs.push_back(receive.getResult());
continue;
}
}
resolvedInputs.push_back(program.externalInputMap.at(input));
}
SmallVector<Value> taskYieldValues;
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
if (isa<SpatCompute>(task.computeInstance.op)) {
IRMapping mapper;
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
mapper.map(oldArg, resolvedInputs[argIndex]);
for (Operation& op : templateBlock) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands())
taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue;
}
Operation* clonedOp = rewriter.clone(op, mapper);
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
}
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
}
}
}
else {
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
IRMapping mapper;
if (templateBlock.getNumArguments() == 1)
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
for (Operation& op : templateBlock) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands())
taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue;
}
Operation* clonedOp = rewriter.clone(op, mapper);
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
if (oldWeightedMvmOp.getWeightIndex() != 0) {
task.computeInstance.op->emitOpError(
"batched per-cpu merge materialization expects lane-local weight index 0");
return failure();
}
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
}
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
if (oldWeightedVmmOp.getWeightIndex() != 0) {
task.computeInstance.op->emitOpError(
"batched per-cpu merge materialization expects lane-local weight index 0");
return failure();
}
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
}
}
}
}
producedValuesByTask[task.computeInstance] = taskYieldValues;
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
if (sendInfos.empty())
continue;
Value producedValue = taskYieldValues[resultIndex];
for (const RemoteSendInfo& sendInfo : sendInfos) {
spatial::SpatChannelSendOp::create(rewriter,
loc,
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
rewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
rewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
producedValue);
}
}
}
}
SmallVector<Value> yieldValues;
yieldValues.reserve(cpuExternalOutputs[cpu].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
auto producedIt = producedValuesByTask.find(outputRef.instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
return failure();
}
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
}
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
}
return success();
}
void replaceExternalUses() {
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
if (!oldComputeOps.contains(use.getOwner()))
use.assign(newValue);
}
}
LogicalResult eraseOldScheduledOps() {
SmallVector<Operation*> orderedOpsToErase;
for (Operation& op : func.getBody().front())
if (oldComputeOps.contains(&op))
orderedOpsToErase.push_back(&op);
for (Operation* op : llvm::reverse(orderedOpsToErase)) {
SmallVector<Operation*> remainingUsers;
for (Value result : op->getResults())
for (Operation* user : result.getUsers())
remainingUsers.push_back(user);
if (!remainingUsers.empty()) {
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
<< "; erase-set=" << (oldComputeOps.contains(op) ? "yes" : "no");
for (Operation* user : remainingUsers) {
diagnostic.attachNote(user->getLoc())
<< "remaining user " << user->getName() << "; erase-set=" << (oldComputeOps.contains(user) ? "yes" : "no");
}
return failure();
}
op->erase();
}
return success();
}
void moveExternalUsersBeforeReturn() {
SmallVector<Operation*> orderedUsersToMove;
for (Operation& op : func.getBody().front()) {
if (&op == returnOp.getOperation())
break;
if (externalUsersToMove.contains(&op))
orderedUsersToMove.push_back(&op);
}
for (Operation* op : orderedUsersToMove)
op->moveBefore(returnOp);
}
func::FuncOp func;
const MergeScheduleResult* schedule = nullptr;
int64_t* nextChannelId = nullptr;
Location loc;
func::ReturnOp returnOp;
SmallVector<ScheduledTask> scheduledTasks;
DenseSet<Operation*> oldComputeOps;
DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
SmallVector<size_t> orderedCpus;
DenseSet<size_t> seenCpus;
DenseSet<Operation*> externalUsersToMove;
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
DenseMap<size_t, SmallVector<Value>> cpuWeights;
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
DenseSet<uint64_t> pairsNeedingReceiveReorder;
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
DenseMap<size_t, CpuProgram> cpuPrograms;
DenseMap<Value, Value> oldToNewExternalValueMap;
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
};
} // namespace
LogicalResult
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId);
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include "Scheduling/MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
class MergeScheduleMaterializer {
public:
mlir::LogicalResult
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
};
} // namespace spatial
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,459 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <chrono>
#include <cstdlib>
#include <limits>
#include <optional>
#include "PostMergeCompaction.hpp"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
class ScopedMergePhaseTimer {
public:
explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
if (enabled)
start = std::chrono::steady_clock::now();
}
~ScopedMergePhaseTimer() {
if (!enabled)
return;
auto elapsed = std::chrono::steady_clock::now() - start;
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
}
private:
bool enabled = false;
std::string phase;
std::chrono::steady_clock::time_point start;
};
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(coreIdAttr.getInt());
return std::nullopt;
}
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
return static_cast<uint64_t>(phaseAttr.getInt());
return std::nullopt;
}
struct RebatchKey {
unsigned inputCount = 0;
unsigned resultCount = 0;
unsigned weightCount = 0;
uint64_t phase = 0;
bool hasPhase = false;
uint64_t structureHash = 0;
bool operator==(const RebatchKey& other) const {
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
}
};
struct RebatchKeyInfo {
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
static unsigned getHashValue(const RebatchKey& key) {
return static_cast<unsigned>(
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
}
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
};
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
RebatchKey computeRebatchKey(SpatCompute compute) {
llvm::hash_code structureHash =
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
for (Value weight : compute.getWeights())
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
structureHash = llvm::hash_combine(structureHash, *phase);
Block& body = compute.getBody().front();
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
for (BlockArgument arg : body.getArguments())
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
for (Operation& op : body) {
structureHash = llvm::hash_combine(
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
for (Type type : op.getResultTypes())
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
for (NamedAttribute attr : op.getAttrs())
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
}
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
return {static_cast<unsigned>(compute.getInputs().size()),
static_cast<unsigned>(compute.getResultTypes().size()),
static_cast<unsigned>(compute.getWeights().size()),
phase.value_or(0),
phase.has_value(),
static_cast<uint64_t>(structureHash)};
}
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
if (!lhs || !rhs)
return false;
if (lhs.getInputs().size() != rhs.getInputs().size())
return false;
if (lhs.getResultTypes() != rhs.getResultTypes())
return false;
if (lhs.getWeights().size() != rhs.getWeights().size())
return false;
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
return false;
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
return false;
auto& lhsBlock = lhs.getBody().front();
auto& rhsBlock = rhs.getBody().front();
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
return false;
DenseMap<Value, Value> mappedValues;
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
if (lhsArg.getType() != rhsArg.getType())
return false;
mappedValues[lhsArg] = rhsArg;
}
auto lhsIt = lhsBlock.begin();
auto rhsIt = rhsBlock.begin();
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
Operation& lhsOp = *lhsIt;
Operation& rhsOp = *rhsIt;
if (lhsOp.getName() != rhsOp.getName())
return false;
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
return false;
if (lhsOp.getNumResults() != rhsOp.getNumResults())
return false;
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
return false;
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
auto mapped = mappedValues.find(lhsOperand);
if (mapped != mappedValues.end()) {
if (mapped->second != rhsOperand)
return false;
continue;
}
if (lhsOperand != rhsOperand)
return false;
}
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
return false;
}
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
return false;
}
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
return false;
}
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
return false;
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
mappedValues[lhsResult] = rhsResult;
}
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
DenseSet<Operation*> consumed;
DenseMap<Operation*, size_t> computeOrder;
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
for (auto [index, compute] : llvm::enumerate(computes)) {
computeOrder[compute.getOperation()] = index;
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
}
for (size_t index = 0; index < computes.size(); ++index) {
auto anchor = computes[index];
if (consumed.contains(anchor))
continue;
if (anchor.getInputs().size() > 1)
continue;
if (!anchor.getResults().empty())
continue;
SmallVector<SpatCompute> group {anchor};
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
if (auto coreId = getComputeCoreId(anchor))
usedCoreIds.insert(*coreId);
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
if (bucketIt == candidatesByKey.end())
continue;
for (auto candidate : bucketIt->second) {
if (computeOrder.lookup(candidate.getOperation()) <= index)
continue;
if (consumed.contains(candidate))
continue;
if (!areEquivalentForRebatch(anchor, candidate))
continue;
if (auto coreId = getComputeCoreId(candidate))
if (!usedCoreIds.insert(*coreId).second)
continue;
group.push_back(candidate);
}
if (group.size() <= 1)
continue;
auto insertionAnchor = group.front();
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
llvm::stable_sort(
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
}
SmallVector<Value> weights;
weights.reserve(group.size() * anchor.getWeights().size());
SmallVector<Value> inputs;
inputs.reserve(group.size() * anchor.getInputs().size());
SmallVector<int32_t> coreIds;
coreIds.reserve(group.size());
bool haveAllCoreIds = true;
for (auto compute : group) {
llvm::append_range(weights, compute.getWeights());
llvm::append_range(inputs, compute.getInputs());
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
haveAllCoreIds = false;
else if (haveAllCoreIds)
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
}
rewriter.setInsertionPoint(insertionAnchor);
auto rebatched = SpatComputeBatch::create(rewriter,
insertionAnchor.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
ValueRange(weights),
ValueRange(inputs));
rebatched.getProperties().setOperandSegmentSizes(
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (haveAllCoreIds)
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
auto* newBlock =
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(newBlock);
IRMapping mapper;
auto& anchorBlock = anchor.getBody().front();
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
for (Operation& anchorOp : anchorBlock) {
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
struct BatchReceiveEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchReceiveEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
entries.push_back(
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchReceiveEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
receiveOp.getLoc(),
receiveOp.getOutput().getType(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
continue;
}
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
struct BatchSendEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchSendEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchSendEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
spatial::SpatChannelSendBatchOp::create(rewriter,
sendOp.getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
mapper.lookup(sendOp.getInput()));
continue;
}
if (isa<spatial::SpatYieldOp>(anchorOp)) {
for (auto& opIt : opIts)
++opIt;
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
continue;
}
Operation* cloned = rewriter.clone(anchorOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
for (auto& opIt : opIts)
++opIt;
}
for (auto compute : group) {
compute->removeAttr(kRebatchPhaseAttrName);
consumed.insert(compute);
rewriter.eraseOp(compute);
}
}
for (auto compute : funcOp.getOps<SpatCompute>())
compute->removeAttr(kRebatchPhaseAttrName);
}
void cleanupDeadPackingOps(func::FuncOp funcOp) {
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
op.erase();
};
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatConcatOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
} // namespace
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
{
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
orderBilateralChannelOps(funcOp);
}
{
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
rebatchEquivalentComputes(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-regular-op-runs");
compactRegularOpRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
compactRowWiseWvmmRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp);
}
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,12 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
namespace onnx_mlir {
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
} // namespace onnx_mlir
@@ -7,12 +7,13 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <tuple>
#include "RegularOpCompaction.hpp" #include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -42,6 +43,47 @@ struct RegularChunk {
Value output; Value output;
}; };
struct RegularCompactionResult {
bool changed = false;
Operation* resumeAfter = nullptr;
};
template <typename OpTy>
struct ConsecutiveRun {
SmallVector<OpTy> ops;
Block::iterator end;
};
template <typename OpTy, typename Predicate>
static ConsecutiveRun<OpTy>
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
ConsecutiveRun<OpTy> run;
run.end = start;
while (run.end != blockEnd) {
auto current = dyn_cast<OpTy>(&*run.end);
if (!current || !predicate(current))
break;
run.ops.push_back(current);
++run.end;
}
return run;
}
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
}
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
SmallVectorImpl<int32_t>& sourceCoreIds,
SmallVectorImpl<int32_t>& targetCoreIds,
uint64_t channelId,
uint32_t sourceCoreId,
uint32_t targetCoreId) {
channelIds.push_back(static_cast<int64_t>(channelId));
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
}
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) { static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
if (values.empty() || !values.front().hasOneUse()) if (values.empty() || !values.front().hasOneUse())
return {}; return {};
@@ -168,6 +210,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
} }
static bool isForwardedChannelPayload(Value value, Block& block) {
Operation* op = value.getDefiningOp();
if (!op || op->getBlock() != &block)
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
}
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) { static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
RegularChunk chunk; RegularChunk chunk;
chunk.startOp = startOp.getOperation(); chunk.startOp = startOp.getOperation();
@@ -202,9 +255,10 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
return chunk; return chunk;
} }
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) { static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
assert(!run.empty() && "expected a non-empty regular chunk run"); assert(!run.empty() && "expected a non-empty regular chunk run");
const RegularChunk& anchorChunk = run.front(); const RegularChunk& anchorChunk = run.front();
RegularCompactionResult result;
SmallVector<Value> inputs; SmallVector<Value> inputs;
inputs.reserve(run.size()); inputs.reserve(run.size());
@@ -214,7 +268,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
rewriter.setInsertionPoint(anchorChunk.startOp); rewriter.setInsertionPoint(anchorChunk.startOp);
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
if (!packedInput) if (!packedInput)
return; return result;
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType()); auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType()); auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
@@ -317,10 +371,79 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
llvm::append_range(opsToErase, chunk.ops); llvm::append_range(opsToErase, chunk.ops);
for (Operation* op : llvm::reverse(opsToErase)) for (Operation* op : llvm::reverse(opsToErase))
rewriter.eraseOp(op); rewriter.eraseOp(op);
result.changed = true;
result.resumeAfter = loop.getOperation()->getNextNode();
return result;
} }
} // namespace } // namespace
void orderBilateralChannelOps(func::FuncOp funcOp) {
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
continue;
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
Block& block = compute.getBody().front();
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
for (Operation& op : block) {
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId)
&& isForwardedChannelPayload(sendOp.getInput(), block)) {
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId());
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
}
continue;
}
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
continue;
}
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId());
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
moves.push_back({receiveOp, firstMatchingSend->second});
}
for (auto [receiveOp, insertionPoint] : moves)
receiveOp->moveBefore(insertionPoint);
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
++it;
continue;
}
Type outputType = receiveOp.getOutput().getType();
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
return current.getOutput().getType() == outputType
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId);
});
if (run.ops.size() > 1) {
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
});
Block::iterator insertIt = run.end;
for (auto op : sorted)
op->moveBefore(&block, insertIt);
}
it = run.end;
}
}
}
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
@@ -329,18 +452,23 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
for (auto it = block.begin(); it != block.end();) { for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it); auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (receiveOp) { if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveOp> run;
Type outputType = receiveOp.getOutput().getType(); Type outputType = receiveOp.getOutput().getType();
auto runIt = it; auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
while (runIt != block.end()) { it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt); return current.getOutput().getType() == outputType;
if (!current || current.getOutput().getType() != outputType) });
bool hasRepeatedEndpoint = false;
DenseSet<uint64_t> seenEndpoints;
for (auto op : run.ops) {
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
if (!seenEndpoints.insert(endpointKey).second) {
hasRepeatedEndpoint = true;
break; break;
run.push_back(current); }
++runIt;
} }
if (run.size() > 1) { if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
struct ReceiveEntry { struct ReceiveEntry {
spatial::SpatChannelReceiveOp op; spatial::SpatChannelReceiveOp op;
size_t originalIndex = 0; size_t originalIndex = 0;
@@ -349,13 +477,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
uint64_t channelId = 0; uint64_t channelId = 0;
}; };
SmallVector<ReceiveEntry> sortedEntries; SmallVector<ReceiveEntry> sortedEntries;
sortedEntries.reserve(run.size()); sortedEntries.reserve(run.ops.size());
for (auto [originalIndex, op] : llvm::enumerate(run)) for (auto [originalIndex, op] : llvm::enumerate(run.ops))
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<int32_t> sourceCoreIds;
@@ -364,13 +488,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
sourceCoreIds.reserve(sortedEntries.size()); sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) { for (ReceiveEntry& entry : sortedEntries) {
(void) entry; appendChannelAttrs(
channelIds.push_back(nextChannelId++); channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
} }
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType()); auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size())); auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
SmallVector<Value> sortedOutputs; SmallVector<Value> sortedOutputs;
sortedOutputs.reserve(sortedEntries.size()); sortedOutputs.reserve(sortedEntries.size());
@@ -383,10 +505,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size())) concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
: RankedTensorType {}; : RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.ops.front());
auto compactReceive = auto compactReceive =
spatial::SpatChannelReceiveTensorOp::create(rewriter, spatial::SpatChannelReceiveTensorOp::create(rewriter,
run.front().getLoc(), run.ops.front().getLoc(),
packedType, packedType,
rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds),
@@ -403,7 +525,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk( entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc())); compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
} }
for (auto op : run) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
it = compactReceive->getIterator(); it = compactReceive->getIterator();
@@ -414,18 +536,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it); auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
if (sendOp) { if (sendOp) {
SmallVector<spatial::SpatChannelSendOp> run;
Type inputType = sendOp.getInput().getType(); Type inputType = sendOp.getInput().getType();
auto runIt = it; auto run =
while (runIt != block.end()) { collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt); return current.getInput().getType() == inputType;
if (!current || current.getInput().getType() != inputType) });
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) { if (run.ops.size() > 1) {
struct SendEntry { struct SendEntry {
spatial::SpatChannelSendOp op; spatial::SpatChannelSendOp op;
uint32_t sourceCoreId = 0; uint32_t sourceCoreId = 0;
@@ -433,13 +550,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
uint64_t channelId = 0; uint64_t channelId = 0;
}; };
SmallVector<SendEntry> sortedEntries; SmallVector<SendEntry> sortedEntries;
sortedEntries.reserve(run.size()); sortedEntries.reserve(run.ops.size());
for (auto op : run) for (auto op : run.ops)
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<int32_t> sourceCoreIds;
@@ -450,26 +563,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
targetCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size()); inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) { for (SendEntry& entry : sortedEntries) {
(void) entry; appendChannelAttrs(
channelIds.push_back(nextChannelId++); channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput()); inputs.push_back(entry.op.getInput());
} }
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.ops.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) { if (packedInput) {
spatial::SpatChannelSendTensorOp::create(rewriter, spatial::SpatChannelSendTensorOp::create(rewriter,
run.front().getLoc(), run.ops.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds), rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput); packedInput);
for (auto op : run) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
it = runIt; it = run.end;
continue; continue;
} }
} }
@@ -488,32 +599,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
for (auto it = block.begin(); it != block.end();) { for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it); auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
if (receiveOp) { if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
Type outputType = receiveOp.getOutput().getType(); Type outputType = receiveOp.getOutput().getType();
auto runIt = it; auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
while (runIt != block.end()) { it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt); return current.getOutput().getType() == outputType;
if (!current || current.getOutput().getType() != outputType) });
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) { if (run.ops.size() > 1) {
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds; SmallVector<int32_t> targetCoreIds;
for (auto op : run) { for (auto op : run.ops) {
llvm::append_range(channelIds, op.getChannelIds()); llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds()); llvm::append_range(targetCoreIds, op.getTargetCoreIds());
} }
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType()); auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size())); auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
SmallVector<Value> outputs; SmallVector<Value> outputs;
outputs.reserve(run.size()); outputs.reserve(run.ops.size());
for (auto op : run) for (auto op : run.ops)
outputs.push_back(op.getOutput()); outputs.push_back(op.getOutput());
unsigned concatStartIndex = 0; unsigned concatStartIndex = 0;
@@ -522,10 +628,10 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size())) concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
: RankedTensorType {}; : RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.ops.front());
auto compactReceive = auto compactReceive =
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter, spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
run.front().getLoc(), run.ops.front().getLoc(),
packedType, packedType,
rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds),
@@ -535,11 +641,11 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter); concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
} }
else { else {
for (auto [index, op] : llvm::enumerate(run)) for (auto [index, op] : llvm::enumerate(run.ops))
op.getOutput().replaceAllUsesWith(extractPackedChunk( op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc())); compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
} }
for (auto op : run) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
it = compactReceive->getIterator(); it = compactReceive->getIterator();
@@ -550,43 +656,38 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it); auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
if (sendOp) { if (sendOp) {
SmallVector<spatial::SpatChannelSendBatchOp> run;
Type inputType = sendOp.getInput().getType(); Type inputType = sendOp.getInput().getType();
auto runIt = it; auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
while (runIt != block.end()) { it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt); return current.getInput().getType() == inputType;
if (!current || current.getInput().getType() != inputType) });
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) { if (run.ops.size() > 1) {
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds; SmallVector<int32_t> targetCoreIds;
SmallVector<Value> inputs; SmallVector<Value> inputs;
inputs.reserve(run.size()); inputs.reserve(run.ops.size());
for (auto op : run) { for (auto op : run.ops) {
llvm::append_range(channelIds, op.getChannelIds()); llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds()); llvm::append_range(targetCoreIds, op.getTargetCoreIds());
inputs.push_back(op.getInput()); inputs.push_back(op.getInput());
} }
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.ops.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) { if (packedInput) {
spatial::SpatChannelSendTensorBatchOp::create(rewriter, spatial::SpatChannelSendTensorBatchOp::create(rewriter,
run.front().getLoc(), run.ops.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds), rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput); packedInput);
for (auto op : run) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
it = runIt; it = run.end;
continue; continue;
} }
} }
@@ -614,8 +715,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
continue; continue;
} }
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
SmallVector<RegularChunk> run {*anchorChunk}; SmallVector<RegularChunk> run {*anchorChunk};
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size())); auto runIt = anchorEndIt;
while (runIt != block.end()) { while (runIt != block.end()) {
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt); auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
if (!candidateStart) if (!candidateStart)
@@ -630,12 +732,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
} }
if (run.size() <= 1) { if (run.size() <= 1) {
++it; it = anchorEndIt;
continue; continue;
} }
compactRegularChunkRun(rewriter, run); size_t originalOpCount = 0;
it = runIt; for (const RegularChunk& chunk : run)
originalOpCount += chunk.ops.size();
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
if (result.changed) {
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
if (!result.resumeAfter) {
it = block.end();
continue;
}
it = result.resumeAfter->getIterator();
continue;
}
it = anchorEndIt;
} }
}; };
@@ -666,37 +782,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
continue; continue;
} }
SmallVector<spatial::SpatVMMOp> run;
auto runIt = it;
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber()); int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
while (runIt != block.end()) { auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt); if (current.getWeightIndex() != wvmmOp.getWeightIndex()
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp || current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|| current.getInput().getType() != wvmmOp.getInput().getType() || current.getInput().getType() != wvmmOp.getInput().getType()
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) { || current.getOutput().getType() != wvmmOp.getOutput().getType())
break; return false;
}
auto currentRow = dyn_cast<OpResult>(current.getInput()); auto currentRow = dyn_cast<OpResult>(current.getInput());
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow)) if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
break; return false;
run.push_back(current);
++expectedRow; ++expectedRow;
++runIt; return true;
} });
if (run.size() <= 1) { if (run.ops.size() <= 1) {
++it; ++it;
continue; continue;
} }
if (!run.front().getOutput().hasOneUse()) { if (!run.ops.front().getOutput().hasOneUse()) {
++it; ++it;
continue; continue;
} }
auto concatUse = run.front().getOutput().getUses().begin(); auto concatUse = run.ops.front().getOutput().getUses().begin();
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner()); auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
if (!concatOp) { if (!concatOp) {
++it; ++it;
@@ -705,7 +816,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
unsigned concatStartIndex = concatUse->getOperandNumber(); unsigned concatStartIndex = concatUse->getOperandNumber();
bool validConcatRun = true; bool validConcatRun = true;
for (auto [index, op] : llvm::enumerate(run)) { for (auto [index, op] : llvm::enumerate(run.ops)) {
if (!op.getOutput().hasOneUse()) { if (!op.getOutput().hasOneUse()) {
validConcatRun = false; validConcatRun = false;
break; break;
@@ -736,17 +847,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
} }
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber()); int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
int64_t runLength = static_cast<int64_t>(run.size()); int64_t runLength = static_cast<int64_t>(run.ops.size());
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType()); auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.ops.front());
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0); auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength); auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1); auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
auto packedInit = auto packedInit =
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType()); tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
auto loop = auto loop =
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
@@ -757,41 +868,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
Value sourceRow = iv; Value sourceRow = iv;
if (firstRow != 0) { if (firstRow != 0) {
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow); auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue); sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
} }
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)}; SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto extractedRow = tensor::ExtractSliceOp::create(rewriter, auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
run.front().getLoc(), run.ops.front().getLoc(),
inputType, inputType,
extractRowsOp.getInput(), extractRowsOp.getInput(),
extractOffsets, extractOffsets,
extractSizes, extractSizes,
extractStrides); extractStrides);
auto loopWvmm = spatial::SpatVMMOp::create( auto loopWvmm = spatial::SpatVMMOp::create(
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)}; SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto inserted = tensor::InsertSliceOp::create( auto inserted = tensor::InsertSliceOp::create(
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides); rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult()); scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
} }
SmallVector<Value> newConcatInputs; SmallVector<Value> newConcatInputs;
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1); newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) { for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
if (operandIndex == concatStartIndex) if (operandIndex == concatStartIndex)
newConcatInputs.push_back(loop.getResult(0)); newConcatInputs.push_back(loop.getResult(0));
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size()) if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
newConcatInputs.push_back(operand); newConcatInputs.push_back(operand);
} }
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); }); rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
for (auto op : run) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
it = loop->getIterator(); it = loop->getIterator();
@@ -6,6 +6,7 @@
namespace onnx_mlir { namespace onnx_mlir {
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId); void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
void compactBatchChannelRuns(mlir::func::FuncOp funcOp); void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
void compactRegularOpRuns(mlir::func::FuncOp funcOp); void compactRegularOpRuns(mlir::func::FuncOp funcOp);
@@ -0,0 +1,189 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <limits>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "ComputeGraph.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
namespace {
Weight getComputeBodyWeight(Region &body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto &block : body)
for ([[maybe_unused]] auto &op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
CrossbarUsage crossbarUsage = 0;
for (auto &block : body)
for (auto &op : block)
if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
bool isUsedAsWeightOnly(Operation *producerOp) {
if (producerOp->getNumResults() == 0)
return false;
for (Value result : producerOp->getResults()) {
if (result.use_empty())
return false;
for (Operation *user : result.getUsers()) {
if (auto compute = dyn_cast<SpatCompute>(user)) {
if (!llvm::is_contained(compute.getWeights(), result))
return false;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
if (!llvm::is_contained(batch.getWeights(), result))
return false;
continue;
}
return false;
}
}
return true;
}
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (const ComputeGraphEdge &edge : edges) {
if (edge.source == edge.target)
continue;
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
}
std::vector<ComputeGraphEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (const auto &[key, weight] : edgeWeights)
aggregatedEdges.push_back({key.first, key.second, weight});
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
if (lhs.source != rhs.source)
return lhs.source < rhs.source;
return lhs.target < rhs.target;
});
return aggregatedEdges;
}
} // namespace
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
static_cast<CrossbarUsage>(instance.laneCount));
}
ComputeGraph buildComputeGraph(Operation *entryOp) {
ComputeGraph graph;
for (Region &region : entryOp->getRegions()) {
for (Block &block : region) {
for (Operation &op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (isUsedAsWeightOnly(spatCompute.getOperation()))
continue;
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
size_t index = graph.nodes.size();
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (isUsedAsWeightOnly(batch.getOperation()))
continue;
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
size_t index = graph.nodes.size();
graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
}
}
}
}
}
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
for (Value input : getComputeInstanceInputs(node.instance)) {
auto producerInstance = getComputeProducerInstance(input);
if (!producerInstance)
continue;
auto producerIt = graph.instanceToIndex.find(*producerInstance);
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
}
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
graph.successors.assign(graph.nodes.size(), {});
graph.predecessors.assign(graph.nodes.size(), {});
for (const ComputeGraphEdge &edge : graph.edges) {
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
}
return graph;
}
bool verifyAcyclic(const ComputeGraph &graph) {
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
std::queue<size_t> readyNodes;
for (size_t node = 0; node < graph.nodes.size(); ++node) {
remainingParents[node] = graph.predecessors[node].size();
if (remainingParents[node] == 0)
readyNodes.push(node);
}
size_t visited = 0;
while (!readyNodes.empty()) {
size_t node = readyNodes.front();
readyNodes.pop();
++visited;
for (const auto &[child, weight] : graph.successors[node]) {
(void) weight;
assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0)
readyNodes.push(child);
}
}
return visited == graph.nodes.size();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,49 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>
#include "../DCPGraph/Utils.hpp"
#include "ComputeInstance.hpp"
#include "ComputeInstanceUtils.hpp"
namespace onnx_mlir {
namespace spatial {
struct ComputeGraphNode {
ComputeInstance instance;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
size_t originalOrder = 0;
};
struct ComputeGraphEdge {
size_t source = 0;
size_t target = 0;
Weight transferCost = 0;
};
struct ComputeGraph {
llvm::SmallVector<ComputeGraphNode> nodes;
llvm::SmallVector<ComputeGraphEdge> edges;
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
};
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
bool verifyAcyclic(const ComputeGraph &graph);
Weight getComputeInstanceWeight(const ComputeInstance &instance);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include <cstdint>
namespace onnx_mlir {
namespace spatial {
struct ComputeInstance {
mlir::Operation *op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance &other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
} // namespace spatial
} // namespace onnx_mlir
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
}
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
const onnx_mlir::spatial::ComputeInstance &rhs) {
return lhs == rhs;
}
};
} // namespace llvm
@@ -0,0 +1,151 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <limits>
#include "ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
size_t getSchedulingCpuBudget() {
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return std::numeric_limits<size_t>::max();
}
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
}
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
size_t chunkIndex = 0;
if (static_cast<size_t>(lane) < largeChunkSpan)
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
else
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
return getBatchChunkForIndex(batch, chunkIndex);
}
SpatCompute getOriginalSpatCompute(Operation *op) {
if (!op)
return {};
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
return dyn_cast<SpatCompute>(op);
}
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return std::nullopt;
//TODO Extract Slice is not the only global non compute operation. There are other legal op
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource();
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto compute = dyn_cast<SpatCompute>(op)) {
return ProducerValueRef {
ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
};
}
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
uint32_t lane = static_cast<uint32_t>(cast<OpResult>(value).getResultNumber());
ComputeInstance instance = getBatchChunkForLane(batch, lane);
size_t resultIndex = static_cast<size_t>(lane - instance.laneStart);
return ProducerValueRef {instance, resultIndex};
}
return std::nullopt;
}
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value))
return producer->instance;
return std::nullopt;
}
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op);
llvm::SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
if (!batch.getInputs().empty())
inputs.push_back(batch.getInputs()[lane]);
return inputs;
}
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
auto batch = cast<SpatComputeBatch>(instance.op);
llvm::SmallVector<Value, 4> weights;
weights.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
weights.push_back(batch.getWeights()[lane]);
return weights;
}
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
auto batch = cast<SpatComputeBatch>(instance.op);
llvm::SmallVector<Value, 4> outputs;
outputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
if (!batch.getOutputs().empty())
outputs.push_back(batch.getOutputs()[lane]);
return outputs;
}
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) {
llvm::SmallVector<Type, 4> outputTypes;
for (Value output : getComputeInstanceOutputValues(instance))
outputTypes.push_back(output.getType());
return outputTypes;
}
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return compute.getBody().front();
return cast<SpatComputeBatch>(instance.op).getBody().front();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,40 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
#include "ComputeInstance.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace spatial {
struct ProducerValueRef {
ComputeInstance instance;
size_t resultIndex = 0;
};
size_t getSchedulingCpuBudget();
size_t getBatchChunkTargetCount(int32_t laneCount);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,720 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <optional>
#include <queue>
#include <vector>
#include "DcpScheduler.hpp"
#include "../DCPGraph/Graph.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
namespace {
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
struct VirtualNode {
llvm::SmallVector<size_t, 4> originalNodeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
struct VirtualGraph {
std::vector<VirtualNode> nodes;
std::vector<IndexedEdge> edges;
};
struct TimingInfo {
std::vector<Time> aest;
std::vector<Time> alst;
std::vector<size_t> topologicalOrder;
bool valid = false;
};
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
CPU cpuCount = 0;
size_t mergedNodeCount = 0;
size_t maxMergeGroupSize = 0;
};
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
if (options.processorCount > 0)
return options.processorCount;
return std::numeric_limits<size_t>::max();
}
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
if (startIndex == endIndex)
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
});
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
VirtualGraph virtualGraph;
virtualGraph.nodes.reserve(graph.nodes.size());
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
VirtualNode virtualNode;
virtualNode.originalNodeIndices.push_back(index);
virtualNode.weight = node.weight;
virtualNode.crossbarUsage = node.crossbarUsage;
virtualGraph.nodes.push_back(std::move(virtualNode));
}
std::vector<IndexedEdge> edges;
edges.reserve(graph.edges.size());
for (const ComputeGraphEdge &edge : graph.edges)
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
virtualGraph.edges = aggregateEdges(edges);
return virtualGraph;
}
TimingInfo computeTiming(const VirtualGraph &graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
timing.alst.assign(nodeCount, 0);
timing.topologicalOrder.reserve(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
for (auto [start, end, weight] : graph.edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
Weight edgeWeight = static_cast<Weight>(weight);
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
children[startIndex].push_back({endIndex, edgeWeight});
parents[endIndex].push_back({startIndex, edgeWeight});
incomingEdgeCount[endIndex]++;
}
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode &node = graph.nodes[nodeIndex];
if (!node.originalNodeIndices.empty())
return node.originalNodeIndices.front();
return nodeIndex;
};
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
size_t lhsKey = getVirtualNodeOrderKey(lhs);
size_t rhsKey = getVirtualNodeOrderKey(rhs);
if (lhsKey != rhsKey)
return lhsKey > rhsKey;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t i = 0; i < nodeCount; ++i)
if (incomingEdgeCount[i] == 0)
readyNodes.push(i);
while (!readyNodes.empty()) {
size_t current = readyNodes.top();
readyNodes.pop();
timing.topologicalOrder.push_back(current);
for (auto [child, weight] : children[current]) {
(void) weight;
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push(child);
}
}
if (timing.topologicalOrder.size() != nodeCount)
return timing;
Time dcpl = 0;
for (size_t nodeIndex : timing.topologicalOrder) {
Time maxParentAest = 0;
for (auto [parent, transferCost] : parents[nodeIndex]) {
maxParentAest =
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
}
timing.aest[nodeIndex] = maxParentAest;
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
}
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
Time minAlst = std::numeric_limits<Time>::max();
if (children[nodeIndex].empty())
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
for (auto [child, transferCost] : children[nodeIndex]) {
minAlst =
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
}
timing.alst[nodeIndex] = minAlst;
}
timing.valid = true;
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto &neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
return lhsSlack < rhsSlack;
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
};
windowSize = std::min(windowSize, ranked.size());
if (windowSize == 0)
return {};
if (windowSize == ranked.size()) {
llvm::sort(ranked, isHigherPriority);
return ranked;
}
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
if (criticalPoolSize < ranked.size())
std::nth_element(
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
std::vector<char> inCriticalPool(ranked.size(), false);
for (size_t i = 0; i < criticalPoolSize; ++i)
inCriticalPool[ranked[i]] = true;
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
std::vector<size_t> selected;
std::vector<char> inWindow(ranked.size(), false);
selected.reserve(windowSize);
struct FrontierEntry {
size_t node;
};
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
selected.push_back(node);
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour] && eligible[neighbour])
frontier.push({neighbour});
};
addToWindow(seed, inCriticalPool);
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, inCriticalPool);
}
if (selected.size() < windowSize) {
std::vector<char> anyNode(ranked.size(), true);
for (size_t node : selected)
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour])
frontier.push({neighbour});
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, anyNode);
}
}
if (selected.size() < windowSize) {
llvm::sort(ranked, isHigherPriority);
for (size_t node : ranked) {
if (selected.size() == windowSize)
break;
if (!inWindow[node]) {
inWindow[node] = true;
selected.push_back(node);
}
}
}
llvm::sort(selected, isHigherPriority);
return selected;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
if (mappedStart == -1 || mappedEnd == -1)
continue;
windowEdges.push_back({mappedStart, mappedEnd, weight});
}
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
llvm::ArrayRef<size_t> selectedNodes,
const DcpScheduleOptions &options,
mlir::MLIRContext *context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
windowWeights.reserve(selectedNodes.size());
windowCrossbarUsage.reserve(selectedNodes.size());
windowNodeOrderKeys.reserve(selectedNodes.size());
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
windowWeights.push_back(graph.nodes[nodeIndex].weight);
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
}
GraphDCP windowGraph(
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
if (options.processorCount > 0)
windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
windowGraph.setContext(context);
windowGraph.runDcp();
WindowScheduleResult result;
result.cpuCount = windowGraph.cpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
result.mergedNodeCount += scheduledTasks.size();
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto &task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph &graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph &coarsenedGraph,
std::vector<size_t> &oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
if (timing.valid)
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
topologicalRank[nodeIndex] = rank;
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto &mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
return topologicalRank[lhs] < topologicalRank[rhs];
return lhs < rhs;
});
}
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
std::vector<size_t> newNodeRank;
oldToNewNode.assign(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
newNodeRank.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
newNodeRank.push_back(topologicalRank[nodeIndex]);
continue;
}
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
}
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode &memberNode = graph.nodes[memberIndex];
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
if (!mergedAny)
return false;
std::vector<IndexedEdge> remappedEdges;
remappedEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
if (newNodeRank[newStart] >= newNodeRank[newEnd])
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return true;
}
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
if (nodeCount > static_cast<size_t>(maxCpuCount))
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
return windowSize;
}
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
nodeIndexByInstance.reserve(graph.nodes.size());
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
nodeIndexByInstance[node.instance] = nodeIndex;
struct ScheduledEdge {
size_t target = 0;
Time delay = 0;
};
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
for (const ComputeGraphEdge &edge : graph.edges) {
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
Time delay = graph.nodes[edge.source].weight;
if (sourceCpu != targetCpu)
delay = addOrMax(delay, edge.transferCost);
scheduledChildren[edge.source].push_back({edge.target, delay});
incomingEdgeCount[edge.target]++;
}
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
for (const ComputeGraphNode &node : graph.nodes) {
size_t cpu = result.computeToCpuMap.lookup(node.instance);
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
});
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
size_t sourceIndex = scheduledTasks[i - 1].second;
size_t targetIndex = scheduledTasks[i].second;
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
incomingEdgeCount[targetIndex]++;
}
}
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
if (incomingEdgeCount[nodeIndex] == 0)
readyNodes.push(nodeIndex);
std::vector<Time> startTimes(graph.nodes.size(), 0);
size_t processedNodeCount = 0;
while (!readyNodes.empty()) {
size_t sourceIndex = readyNodes.top();
readyNodes.pop();
processedNodeCount++;
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
incomingEdgeCount[edge.target]--;
if (incomingEdgeCount[edge.target] == 0)
readyNodes.push(edge.target);
}
}
if (processedNodeCount != graph.nodes.size())
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
}
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
MergeScheduleResult result;
TimingInfo timing = computeTiming(graph);
std::vector<size_t> virtualNodeOrder;
if (timing.valid)
virtualNodeOrder = std::move(timing.topologicalOrder);
else {
virtualNodeOrder.resize(graph.nodes.size());
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
}
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalNodeIndices)
originalNodeToCpu[originalIndex] = cpu;
}
result.dominanceOrderCompute.reserve(originalGraph.nodes.size());
llvm::DenseMap<size_t, size_t> nextCpuSlot;
for (auto [originalIndex, node] : llvm::enumerate(originalGraph.nodes)) {
size_t cpu = originalNodeToCpu[originalIndex];
result.dominanceOrderCompute.push_back(node.instance);
result.computeToCpuMap[node.instance] = cpu;
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
result.cpuToLastComputeMap[cpu] = node.instance;
}
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
assignFeasibleAest(originalGraph, result);
return result;
}
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(graph.nodes.size());
for (const ComputeGraphNode &node : graph.nodes)
result.dominanceOrderCompute.push_back(node.instance);
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
if (scheduledTasks.empty())
continue;
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
result.computeToCpuMap[instance] = cpu;
result.computeToCpuSlotMap[instance] = slot;
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
}
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
result.cpuToLastComputeMap[cpu] = lastInstance;
result.isLastComputeOfCpu.insert(lastInstance);
}
return result;
}
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
llvm::SmallVector<Weight> nodeWeights;
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
llvm::SmallVector<int64_t> nodeOrderKeys;
llvm::SmallVector<IndexedEdge> edges;
nodeWeights.reserve(graph.nodes.size());
nodeCrossbarUsage.reserve(graph.nodes.size());
nodeOrderKeys.reserve(graph.nodes.size());
edges.reserve(graph.edges.size());
for (const ComputeGraphNode &node : graph.nodes) {
nodeWeights.push_back(node.weight);
nodeCrossbarUsage.push_back(node.crossbarUsage);
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
}
for (const ComputeGraphEdge &edge : graph.edges) {
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
}
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
if (options.processorCount > 0)
graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
graphDCP.setContext(context);
graphDCP.runDcp();
return buildResultFromScheduledGraph(graphDCP, graph);
}
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
return false;
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
});
}
} // namespace
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
if (needsExactScheduledBatches(graph, options))
return runLegacyDcp(graph, options, context);
if (options.criticalWindowSize == 0)
return runLegacyDcp(graph, options, context);
VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
size_t iteration = 0;
bool debugCoarsening = isDcpCoarsenDebugEnabled();
auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, options, context);
if (windowSchedule.mergeGroups.empty()) {
if (debugCoarsening && oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount);
return false;
}
VirtualGraph coarsenedGraph;
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount,
windowSchedule.mergeGroups.size(),
windowSchedule.mergedNodeCount,
windowSchedule.maxMergeGroupSize,
coarsenedGraph.nodes.size(),
oldNodeCount - coarsenedGraph.nodes.size());
virtualGraph = std::move(coarsenedGraph);
return true;
};
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget(options)) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
}
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
}
llvm::SmallVector<size_t> selectedNodes;
auto criticalWindow =
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size(), options));
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
selectedNodes.size());
break;
}
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
}
return buildResultFromVirtualGraph(virtualGraph, graph);
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "ComputeGraph.hpp"
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
struct DcpScheduleOptions {
size_t processorCount = 0;
size_t criticalWindowSize = 0;
bool allowFallbackForAutoCoreCount = true;
};
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,25 @@
#pragma once
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstddef>
#include <cstdint>
#include <vector>
#include "ComputeInstance.hpp"
namespace onnx_mlir {
namespace spatial {
struct MergeScheduleResult {
std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
};
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,139 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <vector>
#include "ComputeGraph.hpp"
#include "../DCPGraph/DCPAnalysis.hpp"
#include "DcpScheduler.hpp"
#include "MergeSchedulingAnalysis.hpp"
#include "PeftScheduler.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
namespace spatial {
namespace {
MergeSchedulerKind getSchedulerKind() {
switch (pimMergeScheduler.getValue()) {
case MergeSchedulerPeft:
return MergeSchedulerKind::Peft;
case MergeSchedulerDcp:
return MergeSchedulerKind::Dcp;
}
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
if (!result.computeToCpuMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
if (!result.computeToCpuSlotMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
if (!result.computeToAestMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing start time");
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
});
CrossbarUsage usedCrossbars = 0;
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
if (scheduledTasks[slot].first != slot)
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
if (usedCrossbars > crossbarCapacity)
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
}
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
if (!result.isLastComputeOfCpu.count(expectedLast))
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
}
for (const ComputeGraphEdge &edge : graph.edges) {
const ComputeInstance source = graph.nodes[edge.source].instance;
const ComputeInstance target = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
const size_t targetCpu = result.computeToCpuMap.lookup(target);
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
if (targetStart < earliestTargetStart) {
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
graph.nodes[edge.source].originalOrder,
graph.nodes[edge.target].originalOrder)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
}
}
} // namespace
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
: entryOp(op) {
result = run();
}
MergeScheduleResult MergeSchedulingAnalysis::run() {
verifyExplicitPimCoreCount();
ComputeGraph graph = buildComputeGraph(entryOp);
if (!verifyAcyclic(graph))
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
MergeSchedulingOptions options;
options.kind = getSchedulerKind();
if (coresCount.getValue() > 0)
options.processorCount = static_cast<size_t>(coresCount.getValue());
MergeScheduleResult schedule;
if (options.kind == MergeSchedulerKind::Peft) {
schedule = runPeftScheduler(
graph,
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
else {
schedule = runDcpScheduler(
graph,
DcpScheduleOptions {
options.processorCount,
dcpCriticalWindowSize.getValue(),
options.allowDcpFallbackForAutoCoreCount
},
entryOp->getContext());
}
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
return schedule;
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,36 @@
#pragma once
#include "mlir/IR/Operation.h"
#include <cstddef>
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
enum class MergeSchedulerKind {
Dcp,
Peft,
};
struct MergeSchedulingOptions {
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
size_t processorCount = 0;
bool allowDcpFallbackForAutoCoreCount = true;
};
class MergeSchedulingAnalysis {
public:
explicit MergeSchedulingAnalysis(mlir::Operation *op);
MergeScheduleResult &getResult() { return result; }
private:
mlir::Operation *entryOp = nullptr;
MergeScheduleResult result;
MergeScheduleResult run();
};
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,303 @@
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <queue>
#include <vector>
#include "PeftScheduler.hpp"
namespace onnx_mlir {
namespace spatial {
namespace {
struct ScheduledTask {
size_t processor = std::numeric_limits<size_t>::max();
Time startTime = 0;
Time endTime = 0;
size_t slot = 0;
};
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
std::queue<size_t> readySinks;
std::vector<std::vector<size_t>> reverseLevels;
for (size_t node = 0; node < graph.nodes.size(); ++node) {
remainingSuccessors[node] = graph.successors[node].size();
if (remainingSuccessors[node] == 0)
readySinks.push(node);
}
size_t levelizedCount = 0;
while (!readySinks.empty()) {
size_t levelSize = readySinks.size();
std::vector<size_t> levelNodes;
levelNodes.reserve(levelSize);
for (size_t i = 0; i < levelSize; ++i) {
size_t node = readySinks.front();
readySinks.pop();
levelNodes.push_back(node);
++levelizedCount;
for (const auto& [pred, weight] : graph.predecessors[node]) {
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
if (--remainingSuccessors[pred] == 0)
readySinks.push(pred);
}
}
reverseLevels.push_back(std::move(levelNodes));
}
if (levelizedCount != graph.nodes.size())
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
return reverseLevels;
}
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
constexpr size_t kMaxOctTableBytes = 1ull << 30;
if (nodeCount == 0 || processorCount == 0)
return;
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
size_t rowBytes = processorCount * sizeof(Time);
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
size_t totalBytes = nodeCount * rowBytes;
if (totalBytes > kMaxOctTableBytes) {
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
totalBytes / (1024 * 1024))
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
}
} // namespace
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
const size_t nodeCount = graph.nodes.size();
const size_t processorCount = options.processorCount;
if (processorCount == 0)
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
verifyOctTableSize(nodeCount, processorCount);
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
// MOCK: Replace this with your actual heterogeneous cost lookup.
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
std::vector<Time> oct(nodeCount * processorCount, 0);
std::vector<Time> minOctPlusComp(nodeCount, 0);
// 1. O(P(E+V)) Heterogeneous OCT Calculation
for (const std::vector<size_t>& levelNodes : reverseLevels) {
auto computeNodeOct = [&](size_t levelIndex) {
size_t task = levelNodes[levelIndex];
std::vector<Time> maxVals(processorCount, 0);
for (const auto& [succ, comm] : graph.successors[task]) {
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
for (size_t processor = 0; processor < processorCount; ++processor) {
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
maxVals[processor] = std::max(maxVals[processor], bestSucc);
}
}
Time minForPreds = std::numeric_limits<Time>::max();
for (size_t processor = 0; processor < processorCount; ++processor) {
oct[task * processorCount + processor] = maxVals[processor];
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
}
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
};
if (options.context != nullptr)
mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
else
for (size_t i = 0; i < levelNodes.size(); ++i)
computeNodeOct(i);
}
struct RankEntry {
long double rank = 0.0L;
size_t node = 0;
size_t originalOrder = 0;
};
std::vector<RankEntry> ranks(nodeCount);
auto computeRank = [&](size_t node) {
long double rank = 0.0L;
for (size_t processor = 0; processor < processorCount; ++processor)
rank += static_cast<long double>(oct[node * processorCount + processor]);
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
};
if (options.context != nullptr)
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
else
for (size_t node = 0; node < nodeCount; ++node)
computeRank(node);
auto readyCompare = [&](size_t lhs, size_t rhs) {
const RankEntry& lhsRank = ranks[lhs];
const RankEntry& rhsRank = ranks[rhs];
if (lhsRank.rank != rhsRank.rank)
return lhsRank.rank < rhsRank.rank;
if (lhsRank.originalOrder != rhsRank.originalOrder)
return lhsRank.originalOrder > rhsRank.originalOrder;
return lhs > rhs;
};
std::vector<int> remainingParents(nodeCount, 0);
std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
for (size_t node = 0; node < nodeCount; ++node) {
remainingParents[node] = graph.predecessors[node].size();
if (remainingParents[node] == 0)
readyQueue.push(node);
}
std::vector<char> scheduled(nodeCount, false);
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
std::vector<ScheduledTask> schedules(nodeCount);
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
size_t scheduledCount = 0;
while (!readyQueue.empty()) {
size_t task = readyQueue.top();
readyQueue.pop();
if (scheduled[task])
continue;
size_t bestProcessor = std::numeric_limits<size_t>::max();
Time bestEst = 0;
Time bestEft = 0;
Time bestOeft = std::numeric_limits<Time>::max();
bool crossbarRejected = false;
for (size_t processor = 0; processor < processorCount; ++processor) {
if (graph.nodes[task].crossbarUsage != 0
&& addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
crossbarRejected = true;
continue;
}
Time dataReady = 0;
for (const auto& [pred, comm] : graph.predecessors[task]) {
const ScheduledTask& predSchedule = schedules[pred];
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
}
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
Time compWeight = getComputeCost(task, processor);
Time est = dataReady;
Time currentEnd = 0;
bool foundGap = false;
for (size_t schedTaskIndex : tasksByProcessor[processor]) {
const ScheduledTask& schedTask = schedules[schedTaskIndex];
Time gapStart = std::max(currentEnd, dataReady);
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
est = gapStart;
foundGap = true;
break;
}
currentEnd = schedTask.endTime;
}
if (!foundGap)
est = std::max(currentEnd, dataReady);
Time eft = addOrMax(est, compWeight);
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
}
}
if (bestProcessor == std::numeric_limits<size_t>::max()) {
if (crossbarRejected) {
std::string message =
llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
graph.nodes[task].originalOrder,
options.crossbarCapacity)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
graph.nodes[task].originalOrder,
processorCount)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
schedules[task] = {bestProcessor, bestEst, bestEft, 0};
scheduled[task] = true;
++scheduledCount;
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
// 3. CRITICAL FIX: Topological Append
// Because the readyQueue pops in strict topological order, simply pushing to the
// back guarantees the Monoliths will be physically generated cycle-free.
// The hardware will still benefit from the processor assignment chosen by PEFT.
tasksByProcessor[bestProcessor].push_back(task);
for (const auto& [child, weight] : graph.successors[task]) {
(void) weight;
assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0)
readyQueue.push(child);
}
}
if (scheduledCount != nodeCount)
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
// 4. Build Strict Topological Dominance Order
std::vector<size_t> scheduledOrder(nodeCount);
for (size_t i = 0; i < nodeCount; ++i)
scheduledOrder[i] = i;
std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
});
// 5. Populate Final Result
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(nodeCount);
for (size_t task : scheduledOrder)
result.dominanceOrderCompute.push_back(graph.nodes[task].instance);
for (size_t processor = 0; processor < processorCount; ++processor) {
size_t currentSlot = 0;
for (size_t task : tasksByProcessor[processor]) {
const ComputeInstance instance = graph.nodes[task].instance;
result.computeToCpuMap[instance] = processor;
result.computeToCpuSlotMap[instance] = currentSlot++;
result.computeToAestMap[instance] = schedules[task].startTime;
}
if (!tasksByProcessor[processor].empty()) {
const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
result.cpuToLastComputeMap[processor] = lastInstance;
result.isLastComputeOfCpu.insert(lastInstance);
}
}
return result;
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "ComputeGraph.hpp"
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
struct PeftScheduleOptions {
size_t processorCount = 0;
CrossbarUsage crossbarCapacity = 0;
mlir::MLIRContext *context = nullptr;
};
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
} // namespace spatial
} // namespace onnx_mlir
+4 -1
View File
@@ -24,8 +24,11 @@ struct EmitPimCodePass : PassWrapper<EmitPimCodePass, OperationPass<ModuleOp>> {
createDirectory(pimDir); createDirectory(pimDir);
int compiler_error_code = compileToPimCode(moduleOp, pimDir); int compiler_error_code = compileToPimCode(moduleOp, pimDir);
if (compiler_error_code != CompilerSuccess) if (compiler_error_code != CompilerSuccess) {
moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code "
<< compiler_error_code;
signalPassFailure(); signalPassFailure();
}
} }
}; };
@@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationP
} }
void runOnOperation() override { void runOnOperation() override {
ModuleOp moduleOp = getOperation();
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.enableFolding(); config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) { if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
signalPassFailure(); signalPassFailure();
return; return;
} }
dumpModule(getOperation(), "pim3_folded"); dumpModule(moduleOp, "pim3_folded");
} }
std::shared_ptr<const FrozenRewritePatternSet> patterns; std::shared_ptr<const FrozenRewritePatternSet> patterns;
@@ -66,8 +66,10 @@ static Value buildSubviewChunk(const StaticSubviewInfo& info,
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
} }
static SmallVector<Value> static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
delinearizeIndexValue(Value linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides, PatternRewriter& rewriter) { ArrayRef<int64_t> shape,
ArrayRef<int64_t> strides,
PatternRewriter& rewriter) {
SmallVector<Value> indices; SmallVector<Value> indices;
indices.reserve(shape.size()); indices.reserve(shape.size());
@@ -112,7 +114,8 @@ static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides"); assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter)); chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(1)); chunkSizes.push_back(rewriter.getIndexAttr(1));
} else { }
else {
chunkOffsets.push_back(info.offsets[dim]); chunkOffsets.push_back(info.offsets[dim]);
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back())); chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
} }
@@ -122,11 +125,8 @@ static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
} }
static Value buildContiguousChunk(Value source, static Value buildContiguousChunk(
ArrayRef<int64_t> copyShape, Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
ArrayRef<Value> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets; SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes; SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides; SmallVector<OpFoldResult> chunkStrides;
@@ -203,7 +203,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices = SmallVector<Value> outerIndices =
outerShape.empty() ? SmallVector<Value> {} : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter); outerShape.empty() ? SmallVector<Value> {}
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter); : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
@@ -160,6 +160,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
} }
if (hasFailure) { if (hasFailure) {
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
signalPassFailure(); signalPassFailure();
return; return;
} }
+98 -46
View File
@@ -6,10 +6,11 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
using namespace mlir; using namespace mlir;
@@ -96,6 +97,22 @@ static bool isConstantGlobalView(Value value) {
value = cast.getSource(); value = cast.getSource();
continue; continue;
} }
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return false;
value = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return false;
value = expand.getSrc();
continue;
}
return false; return false;
} }
} }
@@ -152,14 +169,15 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
void runOnOperation() override { void runOnOperation() override {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
bool hasFailure = false; pim::CappedDiagnosticReporter diagnostics;
moduleOp.walk([&](Operation* op) { moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat") if (op->getDialect()->getNamespace() != "spat")
return; return;
op->emitError("illegal Spatial operation reached PIM codegen verification"); diagnostics.report(op, [](Operation* illegalOp) {
hasFailure = true; illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
});
}); });
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
@@ -168,49 +186,56 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
for (Operation& op : funcOp.getBody().front().getOperations()) { for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) { if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp))) (void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
hasFailure = true; (void) verifyCoreOperands(coreOp, diagnostics);
continue; continue;
} }
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) { if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp))) (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
hasFailure = true; (void) verifyCoreOperands(coreBatchOp, diagnostics);
continue; continue;
} }
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) { if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
if (failed(verifyReturnOp(returnOp))) (void) verifyReturnOp(returnOp, diagnostics);
hasFailure = true;
continue; continue;
} }
if (!isAddressOnlyHostOp(&op)) { if (!isAddressOnlyHostOp(&op)) {
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; " diagnostics.report(&op, [](Operation* illegalOp) {
"fold it to constants or lower it into pim.core"); illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
hasFailure = true; "fold it to constants or lower it into pim.core");
});
continue; continue;
} }
if (failed(verifyAddressOnlyHostOp(&op))) (void) verifyAddressOnlyHostOp(&op, diagnostics);
hasFailure = true;
} }
} }
if (hasFailure) if (diagnostics.hasFailure()) {
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
signalPassFailure(); signalPassFailure();
}
} }
private: private:
template <typename CoreOpTy> template <typename CoreOpTy>
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) { static LogicalResult
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
bool hasFailure = false; bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { for (auto it : llvm::enumerate(coreOp.getWeights())) {
size_t weightIndex = it.index();
Value weight = it.value();
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>(); auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp && !isConstantGlobalView(weight)) { if (!getGlobalOp && !isConstantGlobalView(weight)) {
coreOp.emitOpError() << "weight #" << weightIndex diagnostics.report(coreOp.getOperation(), [&](Operation*) {
<< " must be materialized as a constant memref.global or a static view of one before JSON " coreOp.emitOpError() << "weight #" << weightIndex
"codegen"; << " must be materialized as a constant memref.global or a static view of one before "
"JSON codegen";
});
hasFailure = true; hasFailure = true;
continue; continue;
} }
@@ -220,14 +245,18 @@ private:
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) { if (!globalOp) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global"; diagnostics.report(coreOp.getOperation(), [&](Operation*) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
});
hasFailure = true; hasFailure = true;
continue; continue;
} }
if (!globalOp.getConstant() || !globalOp.getInitialValue()) { if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
coreOp.emitOpError() << "weight #" << weightIndex diagnostics.report(coreOp.getOperation(), [&](Operation*) {
<< " must come from a constant memref.global with an initial value"; coreOp.emitOpError() << "weight #" << weightIndex
<< " must come from a constant memref.global with an initial value";
});
hasFailure = true; hasFailure = true;
} }
} }
@@ -235,11 +264,15 @@ private:
return success(!hasFailure); return success(!hasFailure);
} }
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) { static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
bool hasFailure = false; bool hasFailure = false;
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) { for (auto it : llvm::enumerate(returnOp.getOperands())) {
size_t resultIndex = it.index();
Value operand = it.value();
if (!isCodegenAddressableValue(operand)) { if (!isCodegenAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage"; diagnostics.report(returnOp.getOperation(), [&](Operation*) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
});
hasFailure = true; hasFailure = true;
} }
} }
@@ -247,38 +280,50 @@ private:
} }
template <typename CoreOpTy> template <typename CoreOpTy>
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) { static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlock( return walkPimCoreBlock(
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) { coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
bool hasFailure = false; bool hasFailure = false;
if (!isSupportedCoreInstructionOp(&op)) { if (!isSupportedCoreInstructionOp(&op)) {
op.emitOpError("unsupported executable op reached PIM codegen verification"); diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
});
hasFailure = true; hasFailure = true;
} }
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { for (auto it : llvm::enumerate(op.getOperands())) {
size_t operandIndex = it.index();
Value operand = it.value();
if (!isa<BaseMemRefType>(operand.getType())) if (!isa<BaseMemRefType>(operand.getType()))
continue; continue;
auto resolvedAddress = resolveContiguousAddress(operand, knowledge); auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
if (failed(resolvedAddress)) { if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
});
hasFailure = true; hasFailure = true;
continue; continue;
} }
if (isExplicitHostOperand(&op, operandIndex)) { if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand, knowledge)) { if (!isCodegenAddressableValue(operand, knowledge)) {
op.emitOpError() << "host operand #" << operandIndex diagnostics.report(&op, [&](Operation* illegalOp) {
<< " is not backed by contiguous addressable storage"; illegalOp->emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
});
hasFailure = true; hasFailure = true;
} }
continue; continue;
} }
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) { if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex diagnostics.report(&op, [&](Operation* illegalOp) {
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd"; illegalOp->emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with "
"pim.memcp_hd";
});
hasFailure = true; hasFailure = true;
} }
} }
@@ -286,18 +331,20 @@ private:
}); });
} }
static LogicalResult verifyAddressOnlyHostOp(Operation* op) { static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op)) if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlyBase(op, subviewOp.getSource()); return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
if (auto castOp = dyn_cast<memref::CastOp>(op)) if (auto castOp = dyn_cast<memref::CastOp>(op))
return verifyAddressOnlySource(op, castOp.getSource()); return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op)) if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
return verifyAddressOnlySource(op, collapseOp.getSrc()); return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op)) if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
return verifyAddressOnlySource(op, expandOp.getSrc()); return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) { if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) { if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
op->emitOpError("depends on a value that is not backed by addressable storage"); diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
});
return failure(); return failure();
} }
return success(); return success();
@@ -305,19 +352,24 @@ private:
return success(); return success();
} }
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) { static LogicalResult
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
if (isCodegenAddressableValue(source)) if (isCodegenAddressableValue(source))
return success(); return success();
op->emitOpError("depends on a value that is not backed by contiguous addressable storage"); diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
});
return failure(); return failure();
} }
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) { static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
if (isBaseAddressableValue(source)) if (isBaseAddressableValue(source))
return success(); return success();
op->emitOpError("depends on a value that is not backed by addressable storage"); diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
});
return failure(); return failure();
} }
}; };

Some files were not shown because too many files have changed in this diff Show More