From 831b7be4e7aa85e9d4dcb12d6727ea4169e07ece Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Thu, 16 Apr 2026 16:58:44 +0200 Subject: [PATCH] validate.py now also checks pass timings --- validation/raptor.py | 37 +++++++++++++++++++++++++++++++++- validation/subprocess_utils.py | 23 ++++++++++++++++----- validation/validate.py | 36 +++++++++++++++++++++++++++++++-- validation/validate_one.py | 14 ++++++++++--- 4 files changed, 99 insertions(+), 11 deletions(-) diff --git a/validation/raptor.py b/validation/raptor.py index 0746cdf..67498e6 100644 --- a/validation/raptor.py +++ b/validation/raptor.py @@ -1,8 +1,40 @@ +import re import subprocess from pathlib import Path from colorama import Fore, Style from subprocess_utils import run_command_with_reporter +PIM_PASS_LABELS = ( + ("ONNXToSpatialPass", "ONNX to Spatial"), + ("SpatialToPimPass", "Spatial to PIM"), + ("PimBufferizationPass", "Bufferize PIM"), + ("HostConstantFoldingPass", "Fold Host Constants"), + ("MaterializeHostConstantsPass", "Materialize Host Constants"), + ("VerificationPass", "Verify PIM"), + ("EmitPimJsonPass", "Emit PIM JSON"), +) +PIM_PASS_LABEL_BY_SUFFIX = dict(PIM_PASS_LABELS) +TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$") + + +def _parse_pim_pass_timings(output_text): + pass_timings = {} + for line in output_text.splitlines(): + match = TIMING_LINE_RE.match(line) + if not match: + continue + + duration = float(match.group(1)) + pass_name = match.group(2) + for suffix, label in PIM_PASS_LABEL_BY_SUFFIX.items(): + if pass_name.endswith(suffix): + pass_timings[label] = pass_timings.get(label, 0.0) + duration + break + + if not pass_timings: + raise RuntimeError("Raptor timing report did not contain any PIM pass timings.") + return pass_timings + def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, crossbar_size, crossbar_count, cwd=None, reporter=None): @@ -16,16 +48,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, # "--use-experimental-conv-impl=true", f"--crossbar-size={crossbar_size}", f"--crossbar-count={crossbar_count}", + "--enable-timing", ] try: - run_command_with_reporter( + output_text = run_command_with_reporter( [str(raptor_onnx_path)] + [str(arg) for arg in args], cwd=cwd, reporter=reporter, + capture_output=True, ) if reporter is None: print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL) + return _parse_pim_pass_timings(output_text) except subprocess.CalledProcessError: if reporter is None: print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL) diff --git a/validation/subprocess_utils.py b/validation/subprocess_utils.py index c1e68b6..de69ee0 100644 --- a/validation/subprocess_utils.py +++ b/validation/subprocess_utils.py @@ -19,6 +19,7 @@ def _read_chunk(fd, treat_eio_as_eof=False): def _stream_output(fd, process, reporter, treat_eio_as_eof=False): selector = selectors.DefaultSelector() recent_output = bytearray() + captured_output = bytearray() try: selector.register(fd, selectors.EVENT_READ) @@ -34,6 +35,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False): reporter._clear() os.write(1, data) reporter._render() + captured_output.extend(data) recent_output.extend(data) if len(recent_output) > MAX_ERROR_OUTPUT_BYTES: del recent_output[:-MAX_ERROR_OUTPUT_BYTES] @@ -43,12 +45,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False): return_code = process.wait() if return_code != 0: raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output)) + return bytes(captured_output) -def run_command_with_reporter(cmd, cwd=None, reporter=None): +def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False): if reporter is None: + if capture_output: + completed = subprocess.run( + cmd, + cwd=cwd, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + return completed.stdout.decode("utf-8", errors="replace") subprocess.run(cmd, cwd=cwd, check=True) - return + return None try: master_fd, slave_fd = pty.openpty() @@ -60,8 +72,8 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None): stderr=subprocess.STDOUT, ) assert process.stdout is not None - _stream_output(process.stdout.fileno(), process, reporter) - return + output = _stream_output(process.stdout.fileno(), process, reporter) + return output.decode("utf-8", errors="replace") if capture_output else None try: process = subprocess.Popen( @@ -73,4 +85,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None): finally: os.close(slave_fd) - _stream_output(master_fd, process, reporter, treat_eio_as_eof=True) + output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True) + return output.decode("utf-8", errors="replace") if capture_output else None diff --git a/validation/validate.py b/validation/validate.py index 3da7d47..357f6fa 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -8,6 +8,7 @@ import sys from pathlib import Path from colorama import Style, Fore from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network +from raptor import PIM_PASS_LABELS def format_command(cmd): @@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc): reporter.resume() +def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count): + if timed_benchmark_count == 0: + return + + print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL) + for _, label in PIM_PASS_LABELS: + count = pass_timing_counts[label] + if count == 0: + continue + print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s") + print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s") + + def main(): ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.") ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.") @@ -90,11 +104,15 @@ def main(): print("=" * 72) results = {} # relative_path -> passed + pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS} + pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS} + total_timing_sum = 0.0 + timed_benchmark_count = 0 reporter = ProgressReporter(len(onnx_files)) for index, onnx_path in enumerate(onnx_files, start=1): rel = onnx_path.relative_to(operations_dir) try: - passed = validate_network( + result = validate_network( onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, threshold=a.threshold, @@ -102,7 +120,15 @@ def main(): model_index=index, model_total=len(onnx_files), ) - results[str(rel)] = passed + results[str(rel)] = result.passed + if result.pim_pass_timings: + benchmark_total = 0.0 + for label, duration in result.pim_pass_timings.items(): + pass_timing_sums[label] += duration + pass_timing_counts[label] += 1 + benchmark_total += duration + total_timing_sum += benchmark_total + timed_benchmark_count += 1 except subprocess.CalledProcessError as exc: results[str(rel)] = False print_validation_error(reporter, rel, exc) @@ -131,6 +157,12 @@ def main(): print(separator) print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL) print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL) + print_average_pim_pass_timings( + pass_timing_sums, + pass_timing_counts, + total_timing_sum, + timed_benchmark_count, + ) sys.exit(0 if n_passed == n_total else 1) diff --git a/validation/validate_one.py b/validation/validate_one.py index 9b16991..3bca810 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -4,6 +4,7 @@ import numpy as np import subprocess import shutil import sys +from dataclasses import dataclass, field from pathlib import Path from colorama import Style, Fore from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP @@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES) GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation") +@dataclass +class ValidationResult: + passed: bool + pim_pass_timings: dict[str, float] = field(default_factory=dict) + + class ProgressReporter: def __init__(self, total_models, stages_per_model=STAGE_COUNT): self.total_models = total_models @@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL + f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}") failed_with_exception = False + pim_pass_timings = {} try: print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX") @@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.advance() print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") - compile_with_raptor( + pim_pass_timings = compile_with_raptor( network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count, cwd=raptor_dir, reporter=reporter) @@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.record_result(passed) status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL) - return passed + return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings) except Exception: failed_with_exception = True reporter.record_result(False) @@ -352,4 +360,4 @@ if __name__ == '__main__': passed = validate_network( a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir ) - raise SystemExit(0 if passed else 1) + raise SystemExit(0 if passed.passed else 1)