#!/usr/bin/env python3
"""Validate every ONNX model found under an operations directory.

Each ``*.onnx`` file is passed to ``validate_network``; a PASS/FAIL summary
table and the average PIM pass timings are printed at the end.  The process
exits with status 0 only when every model passed.
"""

import argparse
import shlex
import signal
import subprocess
import sys
from pathlib import Path

from colorama import Style, Fore

from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
from raptor import PIM_PASS_LABELS


def format_command(cmd):
    """Return *cmd* as a single printable string.

    A list or tuple is joined with shell quoting (each element is converted
    with ``str`` first); anything else is returned as ``str(cmd)``.
    """
    if isinstance(cmd, (list, tuple)):
        return shlex.join(str(arg) for arg in cmd)
    return str(cmd)


def format_return_status(returncode):
    """Describe a process return code in one sentence.

    A negative *returncode* is reported as termination by signal
    ``-returncode``; the signal name falls back to ``"UNKNOWN"`` when the
    number is not a member of ``signal.Signals``.
    """
    if returncode >= 0:
        return f"Program exited with code {returncode}."

    signal_num = -returncode
    try:
        signal_name = signal.Signals(signal_num).name
    except ValueError:
        signal_name = "UNKNOWN"
    return f"Program terminated by signal {signal_name} ({signal_num})."


def print_validation_error(reporter, rel, exc):
    """Print a failure report for model *rel* to stderr.

    For a ``subprocess.CalledProcessError`` the exit status and the command
    to retry are printed; for any other exception its type name and message.
    *reporter* is suspended while printing and always resumed afterwards.
    """
    reporter.suspend()
    try:
        print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL,
              file=sys.stderr, flush=True)
        if isinstance(exc, subprocess.CalledProcessError):
            print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
            print("Retry command:", file=sys.stderr, flush=True)
            print(format_command(exc.cmd), file=sys.stderr, flush=True)
        else:
            print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        print("=" * 72, file=sys.stderr, flush=True)
    finally:
        # Resume even if writing to stderr failed, so the reporter is not
        # left suspended.
        reporter.resume()


def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
    """Print the mean duration of each PIM pass and of the total.

    Labels are printed in ``PIM_PASS_LABELS`` order; a label with no recorded
    samples is skipped.  Nothing is printed when *timed_benchmark_count* is 0.
    """
    if timed_benchmark_count == 0:
        return

    print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL)
    for _, label in PIM_PASS_LABELS:
        count = pass_timing_counts.get(label, 0)
        if count == 0:
            continue
        print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")
    print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")


def _build_arg_parser():
    """Create the command-line parser for this script."""
    ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
    ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
    ap.add_argument("--onnx-include-dir", help="Path to OnnxMlirRuntime include directory.")
    ap.add_argument("--operations-dir", default=None,
                    help="Root of the operations tree (default: operations).")
    ap.add_argument("--simulator-dir", default=None,
                    help="Path to pim-simulator crate root (default: auto-detected relative to script).")
    ap.add_argument("--threshold", type=float, default=1e-3,
                    help="Max allowed diff per output element.")
    ap.add_argument("--crossbar-size", type=int, default=64)
    ap.add_argument("--crossbar-count", type=int, default=8)
    ap.add_argument("--clean", action="store_true",
                    help="Remove generated validation artifacts under each model workspace and exit.")
    return ap


def _accumulate_pim_pass_timings(pim_pass_timings, pass_timing_sums, pass_timing_counts):
    """Add one model's per-pass durations to the running aggregates.

    Returns the sum of the durations of that model.  A label that is not yet
    a key of the aggregate dicts is added instead of raising ``KeyError``.
    """
    benchmark_total = 0.0
    for label, duration in pim_pass_timings.items():
        pass_timing_sums[label] = pass_timing_sums.get(label, 0.0) + duration
        pass_timing_counts[label] = pass_timing_counts.get(label, 0) + 1
        benchmark_total += duration
    return benchmark_total


def _print_summary(results):
    """Print the PASS/FAIL table for *results* (relative path -> passed)."""
    n_passed = sum(1 for passed in results.values() if passed)
    n_total = len(results)

    status_width = len("Result")
    # ``default=0`` keeps this valid for an empty mapping.
    path_width = max(len("Operation"), max((len(rel) for rel in results), default=0))
    separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"

    print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
    print(separator)
    print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
    print(separator)
    for rel, passed in results.items():
        plain_status = "PASS" if passed else "FAIL"
        # Pad before colouring so the escape codes do not affect alignment.
        colour = Fore.GREEN if passed else Fore.RED
        status = colour + plain_status.ljust(status_width) + Style.RESET_ALL
        print(f"| {rel.ljust(path_width)} | {status} |")
    print(separator)
    print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
    print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)


def main():
    """Run the validation (or ``--clean``) over all ONNX files and exit.

    Exit status: 0 when every model passed (or after ``--clean``), 1 when the
    operations directory is missing, contains no ``.onnx`` file, or at least
    one model failed.  Missing ``--raptor-path`` / ``--onnx-include-dir``
    (without ``--clean``) is reported through ``ArgumentParser.error``.
    """
    ap = _build_arg_parser()
    a = ap.parse_args()

    script_dir = Path(__file__).parent.resolve()
    operations_dir = Path(a.operations_dir).resolve() if a.operations_dir else script_dir / "operations"
    simulator_dir = Path(a.simulator_dir).resolve() if a.simulator_dir else (
        script_dir / ".." / "backend-simulators" / "pim" / "pim-simulator"
    )

    if not operations_dir.is_dir():
        print(Fore.RED + f"Operations directory not found: {operations_dir}" + Style.RESET_ALL)
        sys.exit(1)

    onnx_files = sorted(operations_dir.rglob("*.onnx"))
    if not onnx_files:
        print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
        sys.exit(1)

    if a.clean:
        removed_count = sum(
            len(clean_workspace_artifacts(onnx_path.parent, onnx_path.stem))
            for onnx_path in onnx_files
        )
        print(Style.BRIGHT + f"Removed {removed_count} generated artifact path(s)." + Style.RESET_ALL)
        sys.exit(0)

    # These two options are only needed for an actual validation run, so they
    # are checked here rather than declared ``required`` on the parser.
    missing_args = []
    if not a.raptor_path:
        missing_args.append("--raptor-path")
    if not a.onnx_include_dir:
        missing_args.append("--onnx-include-dir")
    if missing_args:
        ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))

    print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
    print(f"Operations root: {operations_dir}")
    print("=" * 72)

    results = {}  # relative_path -> passed
    pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
    pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
    total_timing_sum = 0.0
    timed_benchmark_count = 0

    reporter = ProgressReporter(len(onnx_files))
    for index, onnx_path in enumerate(onnx_files, start=1):
        rel = str(onnx_path.relative_to(operations_dir))
        try:
            result = validate_network(
                onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
                crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
                threshold=a.threshold,
                reporter=reporter,
                model_index=index,
                model_total=len(onnx_files),
            )
            results[rel] = result.passed
            pim_pass_timings = result.pim_pass_timings
        except Exception as exc:
            # Per-model boundary: one failing model (including a
            # CalledProcessError from a tool it runs) must not stop the rest.
            results[rel] = False
            print_validation_error(reporter, rel, exc)
        else:
            # Timing bookkeeping stays outside the ``try`` so that a problem
            # here can never turn a recorded PASS into a FAIL.
            if pim_pass_timings:
                total_timing_sum += _accumulate_pim_pass_timings(
                    pim_pass_timings, pass_timing_sums, pass_timing_counts,
                )
                timed_benchmark_count += 1

    reporter.finish()

    # Summary
    _print_summary(results)
    print_average_pim_pass_timings(
        pass_timing_sums,
        pass_timing_counts,
        total_timing_sum,
        timed_benchmark_count,
    )
    sys.exit(0 if all(results.values()) else 1)


if __name__ == "__main__":
    main()