validate.py now also checks pass timings

This commit is contained in:
NiccoloN
2026-04-16 16:58:44 +02:00
parent ae93d1c563
commit 831b7be4e7
4 changed files with 99 additions and 11 deletions

View File

@@ -1,8 +1,40 @@
import re
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from colorama import Fore, Style from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter from subprocess_utils import run_command_with_reporter
PIM_PASS_LABELS = (
("ONNXToSpatialPass", "ONNX to Spatial"),
("SpatialToPimPass", "Spatial to PIM"),
("PimBufferizationPass", "Bufferize PIM"),
("HostConstantFoldingPass", "Fold Host Constants"),
("MaterializeHostConstantsPass", "Materialize Host Constants"),
("VerificationPass", "Verify PIM"),
("EmitPimJsonPass", "Emit PIM JSON"),
)
PIM_PASS_LABEL_BY_SUFFIX = dict(PIM_PASS_LABELS)
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
def _parse_pim_pass_timings(output_text):
pass_timings = {}
for line in output_text.splitlines():
match = TIMING_LINE_RE.match(line)
if not match:
continue
duration = float(match.group(1))
pass_name = match.group(2)
for suffix, label in PIM_PASS_LABEL_BY_SUFFIX.items():
if pass_name.endswith(suffix):
pass_timings[label] = pass_timings.get(label, 0.0) + duration
break
if not pass_timings:
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
return pass_timings
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
crossbar_size, crossbar_count, cwd=None, reporter=None): crossbar_size, crossbar_count, cwd=None, reporter=None):
@@ -16,16 +48,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
# "--use-experimental-conv-impl=true", # "--use-experimental-conv-impl=true",
f"--crossbar-size={crossbar_size}", f"--crossbar-size={crossbar_size}",
f"--crossbar-count={crossbar_count}", f"--crossbar-count={crossbar_count}",
"--enable-timing",
] ]
try: try:
run_command_with_reporter( output_text = run_command_with_reporter(
[str(raptor_onnx_path)] + [str(arg) for arg in args], [str(raptor_onnx_path)] + [str(arg) for arg in args],
cwd=cwd, cwd=cwd,
reporter=reporter, reporter=reporter,
capture_output=True,
) )
if reporter is None: if reporter is None:
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL) print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
return _parse_pim_pass_timings(output_text)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
if reporter is None: if reporter is None:
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL) print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)

View File

@@ -19,6 +19,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
def _stream_output(fd, process, reporter, treat_eio_as_eof=False): def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
selector = selectors.DefaultSelector() selector = selectors.DefaultSelector()
recent_output = bytearray() recent_output = bytearray()
captured_output = bytearray()
try: try:
selector.register(fd, selectors.EVENT_READ) selector.register(fd, selectors.EVENT_READ)
@@ -34,6 +35,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
reporter._clear() reporter._clear()
os.write(1, data) os.write(1, data)
reporter._render() reporter._render()
captured_output.extend(data)
recent_output.extend(data) recent_output.extend(data)
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES: if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
del recent_output[:-MAX_ERROR_OUTPUT_BYTES] del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
@@ -43,12 +45,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
return_code = process.wait() return_code = process.wait()
if return_code != 0: if return_code != 0:
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output)) raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
return bytes(captured_output)
def run_command_with_reporter(cmd, cwd=None, reporter=None): def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
if reporter is None: if reporter is None:
if capture_output:
completed = subprocess.run(
cmd,
cwd=cwd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
return completed.stdout.decode("utf-8", errors="replace")
subprocess.run(cmd, cwd=cwd, check=True) subprocess.run(cmd, cwd=cwd, check=True)
return return None
try: try:
master_fd, slave_fd = pty.openpty() master_fd, slave_fd = pty.openpty()
@@ -60,8 +72,8 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
) )
assert process.stdout is not None assert process.stdout is not None
_stream_output(process.stdout.fileno(), process, reporter) output = _stream_output(process.stdout.fileno(), process, reporter)
return return output.decode("utf-8", errors="replace") if capture_output else None
try: try:
process = subprocess.Popen( process = subprocess.Popen(
@@ -73,4 +85,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
finally: finally:
os.close(slave_fd) os.close(slave_fd)
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True) output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
return output.decode("utf-8", errors="replace") if capture_output else None

View File

@@ -8,6 +8,7 @@ import sys
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
from raptor import PIM_PASS_LABELS
def format_command(cmd): def format_command(cmd):
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
reporter.resume() reporter.resume()
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
if timed_benchmark_count == 0:
return
print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL)
for _, label in PIM_PASS_LABELS:
count = pass_timing_counts[label]
if count == 0:
continue
print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")
print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")
def main(): def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.") ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.") ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
@@ -90,11 +104,15 @@ def main():
print("=" * 72) print("=" * 72)
results = {} # relative_path -> passed results = {} # relative_path -> passed
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
total_timing_sum = 0.0
timed_benchmark_count = 0
reporter = ProgressReporter(len(onnx_files)) reporter = ProgressReporter(len(onnx_files))
for index, onnx_path in enumerate(onnx_files, start=1): for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir) rel = onnx_path.relative_to(operations_dir)
try: try:
passed = validate_network( result = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold, threshold=a.threshold,
@@ -102,7 +120,15 @@ def main():
model_index=index, model_index=index,
model_total=len(onnx_files), model_total=len(onnx_files),
) )
results[str(rel)] = passed results[str(rel)] = result.passed
if result.pim_pass_timings:
benchmark_total = 0.0
for label, duration in result.pim_pass_timings.items():
pass_timing_sums[label] += duration
pass_timing_counts[label] += 1
benchmark_total += duration
total_timing_sum += benchmark_total
timed_benchmark_count += 1
except subprocess.CalledProcessError as exc: except subprocess.CalledProcessError as exc:
results[str(rel)] = False results[str(rel)] = False
print_validation_error(reporter, rel, exc) print_validation_error(reporter, rel, exc)
@@ -131,6 +157,12 @@ def main():
print(separator) print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL) print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL) print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
print_average_pim_pass_timings(
pass_timing_sums,
pass_timing_counts,
total_timing_sum,
timed_benchmark_count,
)
sys.exit(0 if n_passed == n_total else 1) sys.exit(0 if n_passed == n_total else 1)

View File

@@ -4,6 +4,7 @@ import numpy as np
import subprocess import subprocess
import shutil import shutil
import sys import sys
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation") GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
@dataclass
class ValidationResult:
passed: bool
pim_pass_timings: dict[str, float] = field(default_factory=dict)
class ProgressReporter: class ProgressReporter:
def __init__(self, total_models, stages_per_model=STAGE_COUNT): def __init__(self, total_models, stages_per_model=STAGE_COUNT):
self.total_models = total_models self.total_models = total_models
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL + reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}") f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
failed_with_exception = False failed_with_exception = False
pim_pass_timings = {}
try: try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX") print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.advance() reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
compile_with_raptor( pim_pass_timings = compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count, crossbar_size, crossbar_count,
cwd=raptor_dir, reporter=reporter) cwd=raptor_dir, reporter=reporter)
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.record_result(passed) reporter.record_result(passed)
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL) reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
except Exception: except Exception:
failed_with_exception = True failed_with_exception = True
reporter.record_result(False) reporter.record_result(False)
@@ -352,4 +360,4 @@ if __name__ == '__main__':
passed = validate_network( passed = validate_network(
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
) )
raise SystemExit(0 if passed else 1) raise SystemExit(0 if passed.passed else 1)