add constant folding and verification pass for pim host operations
better validation scripts output big refactors
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from validate_one import validate_network
|
||||
from validate_one import ProgressReporter, validate_network
|
||||
|
||||
|
||||
def main():
|
||||
@@ -34,32 +35,48 @@ def main():
|
||||
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
|
||||
sys.exit(1)
|
||||
|
||||
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
|
||||
print(f"Operations root: {operations_dir}")
|
||||
print("=" * 72)
|
||||
|
||||
results = {} # relative_path -> passed
|
||||
for onnx_path in onnx_files:
|
||||
reporter = ProgressReporter(len(onnx_files))
|
||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||
rel = onnx_path.relative_to(operations_dir)
|
||||
header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}"
|
||||
print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL)
|
||||
try:
|
||||
passed = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
threshold=a.threshold,
|
||||
reporter=reporter,
|
||||
model_index=index,
|
||||
model_total=len(onnx_files),
|
||||
)
|
||||
results[str(rel)] = passed
|
||||
except (subprocess.CalledProcessError, Exception):
|
||||
results[str(rel)] = False
|
||||
|
||||
passed = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
threshold=a.threshold,
|
||||
)
|
||||
|
||||
results[str(rel)] = passed
|
||||
reporter.finish()
|
||||
|
||||
# Summary
|
||||
n_passed = sum(results.values())
|
||||
n_passed = sum(1 for passed in results.values() if passed)
|
||||
n_total = len(results)
|
||||
print("\n" + Style.BRIGHT + "=" * 60)
|
||||
print(" Summary")
|
||||
print("=" * 60 + Style.RESET_ALL)
|
||||
status_width = len("Result")
|
||||
path_width = max(len("Operation"), *(len(rel) for rel in results))
|
||||
separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
|
||||
|
||||
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
|
||||
print(separator)
|
||||
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
|
||||
print(separator)
|
||||
for rel, passed in results.items():
|
||||
status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL"
|
||||
print(f" {rel}: {status}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." + Style.RESET_ALL)
|
||||
plain_status = "PASS" if passed else "FAIL"
|
||||
status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
|
||||
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
|
||||
print(f"| {rel.ljust(path_width)} | {status} |")
|
||||
print(separator)
|
||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||
|
||||
sys.exit(0 if n_passed == n_total else 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user