validate.py now also checks pass timings
This commit is contained in:
@@ -1,8 +1,40 @@
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from colorama import Fore, Style
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
PIM_PASS_LABELS = (
|
||||
("ONNXToSpatialPass", "ONNX to Spatial"),
|
||||
("SpatialToPimPass", "Spatial to PIM"),
|
||||
("PimBufferizationPass", "Bufferize PIM"),
|
||||
("HostConstantFoldingPass", "Fold Host Constants"),
|
||||
("MaterializeHostConstantsPass", "Materialize Host Constants"),
|
||||
("VerificationPass", "Verify PIM"),
|
||||
("EmitPimJsonPass", "Emit PIM JSON"),
|
||||
)
|
||||
PIM_PASS_LABEL_BY_SUFFIX = dict(PIM_PASS_LABELS)
|
||||
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
|
||||
|
||||
|
||||
def _parse_pim_pass_timings(output_text):
|
||||
pass_timings = {}
|
||||
for line in output_text.splitlines():
|
||||
match = TIMING_LINE_RE.match(line)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
duration = float(match.group(1))
|
||||
pass_name = match.group(2)
|
||||
for suffix, label in PIM_PASS_LABEL_BY_SUFFIX.items():
|
||||
if pass_name.endswith(suffix):
|
||||
pass_timings[label] = pass_timings.get(label, 0.0) + duration
|
||||
break
|
||||
|
||||
if not pass_timings:
|
||||
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
|
||||
return pass_timings
|
||||
|
||||
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
||||
@@ -16,16 +48,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
# "--use-experimental-conv-impl=true",
|
||||
f"--crossbar-size={crossbar_size}",
|
||||
f"--crossbar-count={crossbar_count}",
|
||||
"--enable-timing",
|
||||
]
|
||||
|
||||
try:
|
||||
run_command_with_reporter(
|
||||
output_text = run_command_with_reporter(
|
||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
||||
cwd=cwd,
|
||||
reporter=reporter,
|
||||
capture_output=True,
|
||||
)
|
||||
if reporter is None:
|
||||
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||
return _parse_pim_pass_timings(output_text)
|
||||
except subprocess.CalledProcessError:
|
||||
if reporter is None:
|
||||
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
||||
|
||||
Reference in New Issue
Block a user