validate.py now also checks pass timings
This commit is contained in:
@@ -8,6 +8,7 @@ import sys
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||
from raptor import PIM_PASS_LABELS
|
||||
|
||||
|
||||
def format_command(cmd):
|
||||
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
|
||||
reporter.resume()
|
||||
|
||||
|
||||
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
|
||||
if timed_benchmark_count == 0:
|
||||
return
|
||||
|
||||
print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL)
|
||||
for _, label in PIM_PASS_LABELS:
|
||||
count = pass_timing_counts[label]
|
||||
if count == 0:
|
||||
continue
|
||||
print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")
|
||||
print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||
@@ -90,11 +104,15 @@ def main():
|
||||
print("=" * 72)
|
||||
|
||||
results = {} # relative_path -> passed
|
||||
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
|
||||
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
|
||||
total_timing_sum = 0.0
|
||||
timed_benchmark_count = 0
|
||||
reporter = ProgressReporter(len(onnx_files))
|
||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||
rel = onnx_path.relative_to(operations_dir)
|
||||
try:
|
||||
passed = validate_network(
|
||||
result = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
threshold=a.threshold,
|
||||
@@ -102,7 +120,15 @@ def main():
|
||||
model_index=index,
|
||||
model_total=len(onnx_files),
|
||||
)
|
||||
results[str(rel)] = passed
|
||||
results[str(rel)] = result.passed
|
||||
if result.pim_pass_timings:
|
||||
benchmark_total = 0.0
|
||||
for label, duration in result.pim_pass_timings.items():
|
||||
pass_timing_sums[label] += duration
|
||||
pass_timing_counts[label] += 1
|
||||
benchmark_total += duration
|
||||
total_timing_sum += benchmark_total
|
||||
timed_benchmark_count += 1
|
||||
except subprocess.CalledProcessError as exc:
|
||||
results[str(rel)] = False
|
||||
print_validation_error(reporter, rel, exc)
|
||||
@@ -131,6 +157,12 @@ def main():
|
||||
print(separator)
|
||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
total_timing_sum,
|
||||
timed_benchmark_count,
|
||||
)
|
||||
|
||||
sys.exit(0 if n_passed == n_total else 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user