better progress bar in validate_one.py
This commit is contained in:
@@ -12,7 +12,16 @@ from gen_network_runner import gen_network_runner
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
STAGE_COUNT = 6
|
||||
STAGE_TITLES = (
|
||||
"Compile ONNX",
|
||||
"Build Runner",
|
||||
"Generate Inputs",
|
||||
"Run Reference",
|
||||
"Compile PIM",
|
||||
"Run Simulator",
|
||||
"Compare Outputs",
|
||||
)
|
||||
STAGE_COUNT = len(STAGE_TITLES)
|
||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||
|
||||
|
||||
@@ -22,6 +31,8 @@ class ProgressReporter:
|
||||
self.stages_per_model = stages_per_model
|
||||
self.total_steps = max(1, total_models * stages_per_model)
|
||||
self.completed_steps = 0
|
||||
self.passed_models = 0
|
||||
self.failed_models = 0
|
||||
self.current_label = ""
|
||||
self.enabled = True
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
@@ -36,21 +47,42 @@ class ProgressReporter:
|
||||
return
|
||||
bar_width = 24
|
||||
filled = int(bar_width * self.completed_steps / self.total_steps)
|
||||
counts_text = f"P:{self.passed_models} F:{self.failed_models}"
|
||||
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
|
||||
if len(prefix_text) > self.columns:
|
||||
prefix_text = f"{self.completed_steps}/{self.total_steps}"
|
||||
|
||||
label = f" {self.current_label}" if self.current_label else ""
|
||||
available_label_width = max(0, self.columns - len(prefix_text))
|
||||
label = label[:available_label_width]
|
||||
|
||||
if prefix_text.startswith("["):
|
||||
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
|
||||
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
|
||||
else:
|
||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||
|
||||
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
|
||||
counts = (
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
model_counter = ""
|
||||
label = ""
|
||||
if self.current_label.startswith("[") and "] " in self.current_label:
|
||||
model_counter, label = self.current_label.split("] ", 1)
|
||||
model_counter = f" {model_counter}]"
|
||||
label = f" {label}"
|
||||
elif self.current_label:
|
||||
label = f" {self.current_label}"
|
||||
|
||||
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
|
||||
label = label[:available_label_width]
|
||||
|
||||
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
|
||||
sys.stdout.flush()
|
||||
|
||||
def log(self, message="", color=None):
|
||||
@@ -70,6 +102,13 @@ class ProgressReporter:
|
||||
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
|
||||
self._render()
|
||||
|
||||
def record_result(self, passed):
|
||||
if passed:
|
||||
self.passed_models += 1
|
||||
else:
|
||||
self.failed_models += 1
|
||||
self._render()
|
||||
|
||||
def suspend(self):
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
@@ -112,13 +151,13 @@ def clean_workspace_artifacts(workspace_dir, model_stem):
|
||||
|
||||
def print_stage(reporter, model_index, model_total, model_name, title):
|
||||
stage_colors = {
|
||||
"Compile ONNX": Fore.BLUE,
|
||||
"Build Runner": Fore.MAGENTA,
|
||||
"Generate Inputs": Fore.YELLOW,
|
||||
"Run Reference": Fore.GREEN,
|
||||
"Compile PIM": Fore.CYAN,
|
||||
"Run Simulator": Fore.MAGENTA,
|
||||
"Compare Outputs": Fore.YELLOW,
|
||||
STAGE_TITLES[0]: Fore.BLUE,
|
||||
STAGE_TITLES[1]: Fore.MAGENTA,
|
||||
STAGE_TITLES[2]: Fore.YELLOW,
|
||||
STAGE_TITLES[3]: Fore.GREEN,
|
||||
STAGE_TITLES[4]: Fore.CYAN,
|
||||
STAGE_TITLES[5]: Fore.MAGENTA,
|
||||
STAGE_TITLES[6]: Fore.YELLOW,
|
||||
}
|
||||
color = stage_colors.get(title, Fore.WHITE)
|
||||
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
|
||||
@@ -284,11 +323,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
||||
reporter.resume()
|
||||
reporter.advance()
|
||||
reporter.record_result(passed)
|
||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||
return passed
|
||||
except Exception:
|
||||
failed_with_exception = True
|
||||
reporter.record_result(False)
|
||||
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
||||
reporter.suspend()
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user