validate.py now also checks pass timings

This commit is contained in:
NiccoloN
2026-04-16 16:58:44 +02:00
parent ae93d1c563
commit 831b7be4e7
4 changed files with 99 additions and 11 deletions

View File

@@ -8,6 +8,7 @@ import sys
from pathlib import Path
from colorama import Style, Fore
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
from raptor import PIM_PASS_LABELS
def format_command(cmd):
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
reporter.resume()
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
if timed_benchmark_count == 0:
return
print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL)
for _, label in PIM_PASS_LABELS:
count = pass_timing_counts[label]
if count == 0:
continue
print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")
print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")
def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
@@ -90,11 +104,15 @@ def main():
print("=" * 72)
results = {} # relative_path -> passed
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
total_timing_sum = 0.0
timed_benchmark_count = 0
reporter = ProgressReporter(len(onnx_files))
for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir)
try:
passed = validate_network(
result = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
@@ -102,7 +120,15 @@ def main():
model_index=index,
model_total=len(onnx_files),
)
results[str(rel)] = passed
results[str(rel)] = result.passed
if result.pim_pass_timings:
benchmark_total = 0.0
for label, duration in result.pim_pass_timings.items():
pass_timing_sums[label] += duration
pass_timing_counts[label] += 1
benchmark_total += duration
total_timing_sum += benchmark_total
timed_benchmark_count += 1
except subprocess.CalledProcessError as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
@@ -131,6 +157,12 @@ def main():
print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
print_average_pim_pass_timings(
pass_timing_sums,
pass_timing_counts,
total_timing_sum,
timed_benchmark_count,
)
sys.exit(0 if n_passed == n_total else 1)