import re
import subprocess
from pathlib import Path

from colorama import Fore, Style

from subprocess_utils import run_command_with_reporter

# (pass-name suffix, human-readable label) pairs. A timing-report entry whose
# name ends with the suffix is accounted under the corresponding label.
PIM_PASS_LABELS = (
    ("ONNXToSpatialPass", "ONNX to Spatial"),
    ("MergeComputeNodesPass", "Merge Compute Nodes"),
    ("SpatialToPimPass", "Spatial to PIM"),
    ("PimBufferizationPass", "Bufferize PIM"),
    ("HostConstantFoldingPass", "Fold Host Constants"),
    ("MaterializeHostConstantsPass", "Materialize Host Constants"),
    ("VerificationPass", "Verify PIM"),
    ("EmitPimJsonPass", "Emit PIM JSON"),
)
# Same pairs as a suffix -> label mapping (insertion order is preserved).
PIM_PASS_LABEL_BY_SUFFIX = dict(PIM_PASS_LABELS)
# Matches "<decimal number> (<percentage>%) <name>" lines; group 1 is the
# leading number, group 2 is the remaining text with outer whitespace removed.
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")


def _parse_pim_pass_timings(output_text):
    """Sum the durations reported for each known PIM pass.

    Every line of ``output_text`` that matches ``TIMING_LINE_RE`` is
    inspected; when its name ends with one of the suffixes in
    ``PIM_PASS_LABEL_BY_SUFFIX`` (first matching suffix wins), its duration
    is added to that label's running total. Lines that do not match the
    pattern or do not name a known pass are ignored.

    Args:
        output_text: Text containing the timing report, one entry per line.

    Returns:
        A dict mapping pass label to the summed duration as a float.

    Raises:
        RuntimeError: If no line contributed a timing for any known pass.
    """
    totals = {}
    for raw_line in output_text.splitlines():
        parsed = TIMING_LINE_RE.match(raw_line)
        if parsed is None:
            continue

        seconds_text, reported_name = parsed.groups()
        seconds = float(seconds_text)

        # Labels are always non-empty strings, so None is a safe "no match".
        label = next(
            (
                candidate_label
                for candidate_suffix, candidate_label in PIM_PASS_LABEL_BY_SUFFIX.items()
                if reported_name.endswith(candidate_suffix)
            ),
            None,
        )
        if label is not None:
            totals[label] = totals.get(label, 0.0) + seconds

    if totals:
        return totals
    raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")


def compile_with_raptor(
    network_path,
    raptor_onnx_path: Path,
    output_base: Path,
    crossbar_size,
    crossbar_count,
    cwd=None,
    reporter=None,
):
    """Run Raptor on a network and return its per-pass PIM timings.

    The executable at ``raptor_onnx_path`` is invoked through
    ``run_command_with_reporter`` with the PIM flags, the requested crossbar
    configuration and ``--enable-timing``. The value returned by that helper
    is then parsed with ``_parse_pim_pass_timings``.

    NOTE(review): ``run_command_with_reporter`` is defined elsewhere; it is
    presumably returning the captured output as text and raising
    ``subprocess.CalledProcessError`` on failure - confirm against the helper.

    Args:
        network_path: Input network, passed as the first command argument.
        raptor_onnx_path: Path of the Raptor executable to run.
        output_base: Value passed to the ``-o`` option.
        crossbar_size: Value for ``--crossbar-size``.
        crossbar_count: Value for ``--crossbar-count``.
        cwd: Forwarded unchanged to ``run_command_with_reporter``.
        reporter: Forwarded unchanged to ``run_command_with_reporter``. When
            it is ``None`` a coloured success/failure line is printed here.

    Returns:
        A dict mapping pass label to summed duration.

    Raises:
        subprocess.CalledProcessError: Re-raised unchanged when the command
            invocation raises it.
        RuntimeError: If the output contains no known PIM pass timings.
    """
    print_status = reporter is None

    # Crossbar size and count are configurable through the two f-string flags.
    options = [
        network_path,
        "-o",
        output_base,
        "--maccel=PIM",
        "--EmitPimCodegen",
        # Optional flag, currently not passed:
        # "--use-experimental-conv-impl=true",
        f"--crossbar-size={crossbar_size}",
        f"--crossbar-count={crossbar_count}",
        "--enable-timing",
    ]
    # Every argument is stringified so Path objects and numbers are accepted.
    command = [str(raptor_onnx_path), *(str(option) for option in options)]

    try:
        output_text = run_command_with_reporter(
            command,
            cwd=cwd,
            reporter=reporter,
            capture_output=True,
        )
    except subprocess.CalledProcessError:
        if print_status:
            print(f"{Fore.RED}Raptor execution failed{Style.RESET_ALL}")
        raise
    else:
        if print_status:
            print(f"{Fore.GREEN}Raptor execution successful{Style.RESET_ALL}")
        return _parse_pim_pass_timings(output_text)