diff --git a/validation/validate_one.py b/validation/validate_one.py index d5c2f60..9b16991 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -12,7 +12,16 @@ from gen_network_runner import gen_network_runner from subprocess_utils import run_command_with_reporter -STAGE_COUNT = 6 +STAGE_TITLES = ( + "Compile ONNX", + "Build Runner", + "Generate Inputs", + "Run Reference", + "Compile PIM", + "Run Simulator", + "Compare Outputs", +) +STAGE_COUNT = len(STAGE_TITLES) GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation") @@ -22,6 +31,8 @@ class ProgressReporter: self.stages_per_model = stages_per_model self.total_steps = max(1, total_models * stages_per_model) self.completed_steps = 0 + self.passed_models = 0 + self.failed_models = 0 self.current_label = "" self.enabled = True self.columns = shutil.get_terminal_size((100, 20)).columns @@ -36,21 +47,42 @@ class ProgressReporter: return bar_width = 24 filled = int(bar_width * self.completed_steps / self.total_steps) + counts_text = f"P:{self.passed_models} F:{self.failed_models}" prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}" if len(prefix_text) > self.columns: prefix_text = f"{self.completed_steps}/{self.total_steps}" - label = f" {self.current_label}" if self.current_label else "" - available_label_width = max(0, self.columns - len(prefix_text)) - label = label[:available_label_width] - if prefix_text.startswith("["): bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled)) prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL else: prefix = Fore.CYAN + prefix_text + Style.RESET_ALL - sys.stdout.write("\r" + prefix + label + Style.RESET_ALL) + counts = ( + " " + + Style.BRIGHT + + Fore.GREEN + + f"P:{self.passed_models}" + + Style.RESET_ALL + + " " + + Style.BRIGHT + + Fore.RED + + f"F:{self.failed_models}" + + Style.RESET_ALL + ) + model_counter = "" + label = "" + if self.current_label.startswith("[") and "] " in self.current_label: + model_counter, label = self.current_label.split("] ", 1) + model_counter = f" {model_counter}]" + label = f" {label}" + elif self.current_label: + label = f" {self.current_label}" + + available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3) + label = label[:available_label_width] + + sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL) sys.stdout.flush() def log(self, message="", color=None): @@ -70,6 +102,13 @@ class ProgressReporter: self.completed_steps = min(self.total_steps, self.completed_steps + 1) self._render() + def record_result(self, passed): + if passed: + self.passed_models += 1 + else: + self.failed_models += 1 + self._render() + def suspend(self): self.suspended = True self._clear() @@ -112,13 +151,13 @@ def clean_workspace_artifacts(workspace_dir, model_stem): def print_stage(reporter, model_index, model_total, model_name, title): stage_colors = { - "Compile ONNX": Fore.BLUE, - "Build Runner": Fore.MAGENTA, - "Generate Inputs": Fore.YELLOW, - "Run Reference": Fore.GREEN, - "Compile PIM": Fore.CYAN, - "Run Simulator": Fore.MAGENTA, - "Compare Outputs": Fore.YELLOW, + STAGE_TITLES[0]: Fore.BLUE, + STAGE_TITLES[1]: Fore.MAGENTA, + STAGE_TITLES[2]: Fore.YELLOW, + STAGE_TITLES[3]: Fore.GREEN, + STAGE_TITLES[4]: Fore.CYAN, + STAGE_TITLES[5]: Fore.MAGENTA, + STAGE_TITLES[6]: Fore.YELLOW, } color = stage_colors.get(title, Fore.WHITE) reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL) @@ -284,11 +323,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold) reporter.resume() reporter.advance() + reporter.record_result(passed) status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL) return passed except Exception: failed_with_exception = True + reporter.record_result(False) reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL) reporter.suspend() raise