140 lines
5.6 KiB
Python
140 lines
5.6 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import shlex
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from colorama import Style, Fore
|
|
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
|
|
|
|
|
def format_command(cmd):
    """Render ``cmd`` as a single display string.

    A ``list`` or ``tuple`` is treated as an argv sequence: every element is
    converted with ``str`` and the sequence is joined with POSIX shell quoting
    (``shlex.join``), so the result can be pasted back into a shell. Any other
    value is returned as ``str(cmd)``.
    """
    if not isinstance(cmd, (list, tuple)):
        return str(cmd)
    argv = map(str, cmd)
    return shlex.join(argv)
|
|
|
|
|
|
def format_return_status(returncode):
    """Describe a process return code as a one-line, human-readable sentence.

    Negative values are interpreted as ``-N`` = "terminated by signal N"
    (the :mod:`subprocess` convention). The symbolic signal name is looked up
    via :class:`signal.Signals`, and ``UNKNOWN`` is used when the number is not
    a known signal. Non-negative values are reported as plain exit codes.
    """
    if returncode >= 0:
        return f"Program exited with code {returncode}."

    signum = -returncode
    try:
        name = signal.Signals(signum).name
    except ValueError:
        name = "UNKNOWN"
    return f"Program terminated by signal {name} ({signum})."
|
|
|
|
|
|
def print_validation_error(reporter, rel, exc):
    """Write a failure report for one model to stderr.

    The progress reporter is suspended while the report is written. ``resume()``
    sits in a ``finally`` clause, so every ``suspend()`` is paired with a
    ``resume()`` even if writing to stderr raises.

    Args:
        reporter: object exposing ``suspend()`` and ``resume()`` (``main``
            passes its ``ProgressReporter``).
        rel: model path relative to the operations root; used only for display.
        exc: the exception raised while validating the model.
            - For :class:`subprocess.CalledProcessError`, the exit status
              and a shell-quoted retry command are shown.
            - Otherwise, the exception type name and message are shown.
    """
    # Build the whole report first so that formatting happens before the
    # reporter is suspended.
    lines = [Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL]
    if isinstance(exc, subprocess.CalledProcessError):
        lines.append(format_return_status(exc.returncode))
        lines.append("Retry command:")
        lines.append(format_command(exc.cmd))
    else:
        lines.append(f"{type(exc).__name__}: {exc}")
    lines.append("=" * 72)

    reporter.suspend()
    try:
        # A single write plus a single flush keeps the report contiguous on stderr.
        print("\n".join(lines), file=sys.stderr, flush=True)
    finally:
        reporter.resume()
|
|
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser used by ``main``."""
    ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
    # --raptor-path / --onnx-include-dir are not required=True because
    # --clean does not need them; main() checks them after handling --clean.
    ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
    ap.add_argument("--onnx-include-dir", help="Path to OnnxMlirRuntime include directory.")
    ap.add_argument("--operations-dir", default=None,
                    help="Root of the operations tree (default: operations/ next to this script).")
    ap.add_argument("--simulator-dir", default=None,
                    help="Path to pim-simulator crate root (default: auto-detected relative to script).")
    ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
    ap.add_argument("--crossbar-size", type=int, default=64)
    ap.add_argument("--crossbar-count", type=int, default=8)
    ap.add_argument("--clean", action="store_true",
                    help="Remove generated validation artifacts under each model workspace and exit.")
    return ap


def _validate_all(onnx_files, operations_dir, simulator_dir, args):
    """Validate every model in order and return ``{relative path: passed}``.

    Any ``Exception`` raised for a single model is reported and recorded as a
    failure, so the remaining models are still validated.
    """
    results = {}
    total = len(onnx_files)
    reporter = ProgressReporter(total)
    try:
        for index, onnx_path in enumerate(onnx_files, start=1):
            rel = str(onnx_path.relative_to(operations_dir))
            try:
                passed = validate_network(
                    onnx_path, args.raptor_path, args.onnx_include_dir, simulator_dir,
                    crossbar_size=args.crossbar_size, crossbar_count=args.crossbar_count,
                    threshold=args.threshold,
                    reporter=reporter,
                    model_index=index,
                    model_total=total,
                )
            except Exception as exc:
                # Covers failing subprocesses (CalledProcessError) and anything else.
                # The reporting is the same for both; the helper distinguishes them.
                passed = False
                print_validation_error(reporter, rel, exc)
            results[rel] = passed
    finally:
        # Also finalize the progress display if the run is interrupted (e.g. Ctrl-C).
        reporter.finish()
    return results


def _print_summary(results):
    """Print a PASS/FAIL table plus totals for ``results`` and return the pass count."""
    n_passed = sum(1 for passed in results.values() if passed)
    status_width = len("Result")
    # The list always holds the header width, so max() never sees an empty sequence.
    path_width = max([len("Operation"), *(len(rel) for rel in results)])
    separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"

    print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
    print(separator)
    print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
    print(separator)
    for rel, passed in results.items():
        color, label = (Fore.GREEN, "PASS") if passed else (Fore.RED, "FAIL")
        # Pad the plain label before colouring so ANSI codes don't skew alignment.
        status = color + label.ljust(status_width) + Style.RESET_ALL
        print(f"| {rel.ljust(path_width)} | {status} |")
    print(separator)
    print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
    print(Style.BRIGHT + f"Failed: {len(results) - n_passed}" + Style.RESET_ALL)
    return n_passed


def main():
    """Validate every ``*.onnx`` model under the operations tree and print a summary.

    Exit status:
        0: every model passed, or ``--clean`` finished removing artifacts.
        1: at least one model failed, the operations directory does not exist,
           or it contains no ``.onnx`` files.
        2: argparse usage error, e.g. ``--raptor-path`` or
           ``--onnx-include-dir`` is missing without ``--clean``.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    script_dir = Path(__file__).parent.resolve()
    operations_dir = Path(args.operations_dir).resolve() if args.operations_dir else script_dir / "operations"
    simulator_dir = Path(args.simulator_dir).resolve() if args.simulator_dir else (
        script_dir / ".." / "backend-simulators" / "pim" / "pim-simulator"
    )

    if not operations_dir.is_dir():
        print(Fore.RED + f"Operations directory not found: {operations_dir}" + Style.RESET_ALL,
              file=sys.stderr)
        sys.exit(1)

    onnx_files = sorted(operations_dir.rglob("*.onnx"))
    if not onnx_files:
        print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL,
              file=sys.stderr)
        sys.exit(1)

    if args.clean:
        removed_count = sum(len(clean_workspace_artifacts(p.parent, p.stem)) for p in onnx_files)
        print(Style.BRIGHT + f"Removed {removed_count} generated artifact path(s)." + Style.RESET_ALL)
        sys.exit(0)

    missing_args = []
    if not args.raptor_path:
        missing_args.append("--raptor-path")
    if not args.onnx_include_dir:
        missing_args.append("--onnx-include-dir")
    if missing_args:
        ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))

    print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
    print(f"Operations root: {operations_dir}")
    print("=" * 72)

    results = _validate_all(onnx_files, operations_dir, simulator_dir, args)
    n_passed = _print_summary(results)
    sys.exit(0 if n_passed == len(results) else 1)
|
|
|
|
|
|
# Script entry point: main() runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|