Files
Raptor/validation/raptor.py

82 lines
2.7 KiB
Python

import re
import shlex
import subprocess
from pathlib import Path
from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter
# (pass-name suffix, human-readable label) pairs for the PIM passes whose
# durations are extracted from Raptor's timing report.
PIM_PASS_LABELS = (
    ("ONNXToSpatialPass", "ONNX to Spatial"),
    ("MergeComputeNodesPass", "Merge Compute Nodes"),
    ("SpatialToPimPass", "Spatial to PIM"),
    ("PimBufferizationPass", "Bufferize PIM"),
    ("HostConstantFoldingPass", "Fold Host Constants"),
    ("MaterializeHostConstantsPass", "Materialize Host Constants"),
    ("VerificationPass", "Verify PIM"),
    ("EmitPimJsonPass", "Emit PIM JSON"),
)
# Suffix -> label lookup. Insertion order follows PIM_PASS_LABELS, so when a
# report row ends with more than one suffix, the earlier entry is tried first.
PIM_PASS_LABEL_BY_SUFFIX = {suffix: label for suffix, label in PIM_PASS_LABELS}
# One timing-report row: "<time> (<percent>%) <name>". Group 1 captures the
# time value; group 2 captures the name without trailing whitespace.
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
def _parse_pim_pass_timings(output_text):
    """Total the reported durations of the known PIM passes.

    Every line of ``output_text`` is tested against ``TIMING_LINE_RE``. For
    each matching row, the duration is credited to the first entry of
    ``PIM_PASS_LABEL_BY_SUFFIX`` whose suffix ends the row's name. Rows that
    name no known pass are skipped. Repeated rows for the same pass are
    summed.

    Args:
        output_text: Captured text of a Raptor run with timing enabled.

    Returns:
        dict mapping pass label to its accumulated duration. Values are as
        printed in the report (presumably seconds; confirm against the tool).
        Keys appear in the order each label was first encountered.

    Raises:
        RuntimeError: If no row matched any known PIM pass.
    """
    totals = {}
    for row in output_text.splitlines():
        timing = TIMING_LINE_RE.match(row)
        if timing is None:
            continue
        time_text, name = timing.groups()
        # First suffix (in table order) that terminates the pass name wins.
        label = next(
            (lbl for sfx, lbl in PIM_PASS_LABEL_BY_SUFFIX.items() if name.endswith(sfx)),
            None,
        )
        if label is not None:
            totals[label] = totals.get(label, 0.0) + float(time_text)
    if totals:
        return totals
    raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
def _format_command(cmd):
return shlex.join(str(arg) for arg in cmd)
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
                        crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
    """Run the Raptor compiler for the PIM target and return per-pass timings.

    Args:
        network_path: Input model file, passed as Raptor's first argument.
        raptor_onnx_path: Raptor executable to invoke.
        output_base: Value passed to Raptor's ``-o`` option.
        crossbar_size: Forwarded as ``--crossbar-size``.
        crossbar_count: Forwarded as ``--crossbar-count``.
        core_count: Forwarded as ``--core-count`` only when not ``None``.
        cwd: Working directory handed to ``run_command_with_reporter``.
        reporter: Optional reporter. When given, the command is logged through
            ``reporter.log`` and the reporter is passed to the runner. When
            ``None``, the command and a coloured success/failure line are
            printed to stdout instead.

    Returns:
        dict of PIM pass label -> accumulated duration, parsed from the
        captured output by ``_parse_pim_pass_timings``.

    Raises:
        subprocess.CalledProcessError: Re-raised from the runner (presumably
            on a non-zero exit; see ``run_command_with_reporter``).
        RuntimeError: If the output contains no recognised PIM pass timings.
    """
    # Executable first, then the input model and the fixed PIM options.
    # "--enable-timing" presumably makes Raptor emit the report parsed below.
    # An optional flag, currently disabled: "--use-experimental-conv-impl=true"
    cmd = [
        str(raptor_onnx_path),
        str(network_path),
        "-o",
        str(output_base),
        "--maccel=PIM",
        "--EmitPimCodegen",
        f"--crossbar-size={crossbar_size}",
        f"--crossbar-count={crossbar_count}",
        "--enable-timing",
    ]
    if core_count is not None:
        cmd.append(f"--core-count={core_count}")

    rendered = _format_command(cmd)
    if reporter is None:
        print(f"Raptor command: {rendered}")
    else:
        reporter.log(f" Raptor command: {rendered}")

    try:
        output_text = run_command_with_reporter(
            cmd,
            cwd=cwd,
            reporter=reporter,
            capture_output=True,
        )
    except subprocess.CalledProcessError:
        if reporter is None:
            print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
        raise

    if reporter is None:
        print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
    return _parse_pim_pass_timings(output_text)