validate.py now also checks pass timings

This commit is contained in:
NiccoloN
2026-04-16 16:58:44 +02:00
parent ae93d1c563
commit 831b7be4e7
4 changed files with 99 additions and 11 deletions

View File

@@ -4,6 +4,7 @@ import numpy as np
import subprocess
import shutil
import sys
from dataclasses import dataclass, field
from pathlib import Path
from colorama import Style, Fore
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
@dataclass
class ValidationResult:
passed: bool
pim_pass_timings: dict[str, float] = field(default_factory=dict)
class ProgressReporter:
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
self.total_models = total_models
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
failed_with_exception = False
pim_pass_timings = {}
try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
compile_with_raptor(
pim_pass_timings = compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count,
cwd=raptor_dir, reporter=reporter)
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.record_result(passed)
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
except Exception:
failed_with_exception = True
reporter.record_result(False)
@@ -352,4 +360,4 @@ if __name__ == '__main__':
passed = validate_network(
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
)
raise SystemExit(0 if passed else 1)
raise SystemExit(0 if passed.passed else 1)