Files
Raptor/validation/validate.py
2026-02-26 16:34:31 +01:00

69 lines
2.7 KiB
Python

#!/usr/bin/env python3
import argparse
import sys
from pathlib import Path
from colorama import Style, Fore
from validate_one import validate_network
def main():
    """Validate every ONNX model found under the operations directory.

    Parses the command-line options, recursively collects all ``*.onnx``
    files below the operations root, runs ``validate_network`` on each one
    and prints a colored PASS/FAIL summary.

    The function never returns normally; it always ends in ``sys.exit``:

    * status 0 when every model passed;
    * status 1 when the operations directory does not exist, when it holds
      no ``.onnx`` file, or when at least one model failed.

    A model whose validation raises an exception is reported (with its
    traceback) and counted as failed, so a single crash does not discard the
    results of the remaining models. Invalid command-line arguments are
    rejected by ``argparse`` itself before any validation starts.
    """
    # Function-scope import: only needed on the error path of the loop below.
    import traceback

    parser = argparse.ArgumentParser(
        description="Validate all ONNX operations under the operations/ directory."
    )
    parser.add_argument("--raptor-path", required=True, help="Path to the Raptor compiler binary.")
    parser.add_argument("--onnx-include-dir", required=True, help="Path to OnnxMlirRuntime include directory.")
    parser.add_argument("--operations-dir", default=None, help="Root of the operations tree (default: operations).")
    parser.add_argument("--simulator-dir", default=None,
                        help="Path to pim-simulator crate root (default: auto-detected relative to script).")
    parser.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
    parser.add_argument("--crossbar-size", type=int, default=64)
    parser.add_argument("--crossbar-count", type=int, default=8)
    args = parser.parse_args()

    # Defaults are located relative to this script, not the current directory.
    script_dir = Path(__file__).parent.resolve()
    if args.operations_dir:
        operations_dir = Path(args.operations_dir).resolve()
    else:
        operations_dir = script_dir / "operations"
    if args.simulator_dir:
        simulator_dir = Path(args.simulator_dir).resolve()
    else:
        # Resolve so the default is a normalized absolute path (no ".."
        # segment), consistent with a user-supplied --simulator-dir.
        simulator_dir = (script_dir / ".." / "backend-simulators" / "pim" / "pim-simulator").resolve()

    if not operations_dir.is_dir():
        print(Fore.RED + f"Operations directory not found: {operations_dir}" + Style.RESET_ALL)
        sys.exit(1)

    # Sorted so the validation order and the summary are deterministic.
    onnx_files = sorted(operations_dir.rglob("*.onnx"))
    if not onnx_files:
        print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
        sys.exit(1)

    print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL)

    results = {}  # relative_path -> passed
    for onnx_path in onnx_files:
        rel = onnx_path.relative_to(operations_dir)
        header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}"
        print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL)
        try:
            passed = validate_network(
                onnx_path, args.raptor_path, args.onnx_include_dir, simulator_dir,
                crossbar_size=args.crossbar_size, crossbar_count=args.crossbar_count,
                threshold=args.threshold,
            )
        except Exception:
            # Top-level boundary of a batch run: report the crash and keep
            # going so the other models are still validated and summarized.
            traceback.print_exc()
            print(Fore.RED + f"Validation of {rel} raised an exception; counting it as failed."
                  + Style.RESET_ALL)
            passed = False
        # Coerce to bool so the count below is well-defined even if the
        # validator returns a non-bool value such as None.
        results[str(rel)] = bool(passed)

    # Summary
    n_passed = sum(results.values())
    n_total = len(results)
    print("\n" + Style.BRIGHT + "=" * 60)
    print(" Summary")
    print("=" * 60 + Style.RESET_ALL)
    for rel, passed in results.items():
        status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL"
        print(f" {rel}: {status}" + Style.RESET_ALL)
    print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." + Style.RESET_ALL)
    sys.exit(0 if n_passed == n_total else 1)
# Run the batch validation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()