better progress bar in validate_one.py

This commit is contained in:
NiccoloN
2026-04-14 13:13:56 +02:00
parent eade488d13
commit 525792e545

View File

@@ -12,7 +12,16 @@ from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
STAGE_COUNT = 6
STAGE_TITLES = (
"Compile ONNX",
"Build Runner",
"Generate Inputs",
"Run Reference",
"Compile PIM",
"Run Simulator",
"Compare Outputs",
)
STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
@@ -22,6 +31,8 @@ class ProgressReporter:
self.stages_per_model = stages_per_model
self.total_steps = max(1, total_models * stages_per_model)
self.completed_steps = 0
self.passed_models = 0
self.failed_models = 0
self.current_label = ""
self.enabled = True
self.columns = shutil.get_terminal_size((100, 20)).columns
@@ -36,21 +47,42 @@ class ProgressReporter:
return
bar_width = 24
filled = int(bar_width * self.completed_steps / self.total_steps)
counts_text = f"P:{self.passed_models} F:{self.failed_models}"
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
if len(prefix_text) > self.columns:
prefix_text = f"{self.completed_steps}/{self.total_steps}"
label = f" {self.current_label}" if self.current_label else ""
available_label_width = max(0, self.columns - len(prefix_text))
label = label[:available_label_width]
if prefix_text.startswith("["):
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
else:
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
counts = (
" "
+ Style.BRIGHT
+ Fore.GREEN
+ f"P:{self.passed_models}"
+ Style.RESET_ALL
+ " "
+ Style.BRIGHT
+ Fore.RED
+ f"F:{self.failed_models}"
+ Style.RESET_ALL
)
model_counter = ""
label = ""
if self.current_label.startswith("[") and "] " in self.current_label:
model_counter, label = self.current_label.split("] ", 1)
model_counter = f" {model_counter}]"
label = f" {label}"
elif self.current_label:
label = f" {self.current_label}"
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
label = label[:available_label_width]
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
sys.stdout.flush()
def log(self, message="", color=None):
@@ -70,6 +102,13 @@ class ProgressReporter:
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
self._render()
def record_result(self, passed):
if passed:
self.passed_models += 1
else:
self.failed_models += 1
self._render()
def suspend(self):
self.suspended = True
self._clear()
@@ -112,13 +151,13 @@ def clean_workspace_artifacts(workspace_dir, model_stem):
def print_stage(reporter, model_index, model_total, model_name, title):
stage_colors = {
"Compile ONNX": Fore.BLUE,
"Build Runner": Fore.MAGENTA,
"Generate Inputs": Fore.YELLOW,
"Run Reference": Fore.GREEN,
"Compile PIM": Fore.CYAN,
"Run Simulator": Fore.MAGENTA,
"Compare Outputs": Fore.YELLOW,
STAGE_TITLES[0]: Fore.BLUE,
STAGE_TITLES[1]: Fore.MAGENTA,
STAGE_TITLES[2]: Fore.YELLOW,
STAGE_TITLES[3]: Fore.GREEN,
STAGE_TITLES[4]: Fore.CYAN,
STAGE_TITLES[5]: Fore.MAGENTA,
STAGE_TITLES[6]: Fore.YELLOW,
}
color = stage_colors.get(title, Fore.WHITE)
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
@@ -284,11 +323,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
reporter.resume()
reporter.advance()
reporter.record_result(passed)
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
except Exception:
failed_with_exception = True
reporter.record_result(False)
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
reporter.suspend()
raise