validate.py now also checks pass timings
This commit is contained in:
@@ -4,6 +4,7 @@ import numpy as np
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
||||
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
|
||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
passed: bool
|
||||
pim_pass_timings: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ProgressReporter:
|
||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
||||
self.total_models = total_models
|
||||
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||
failed_with_exception = False
|
||||
pim_pass_timings = {}
|
||||
|
||||
try:
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.advance()
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
compile_with_raptor(
|
||||
pim_pass_timings = compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.record_result(passed)
|
||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||
return passed
|
||||
return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
|
||||
except Exception:
|
||||
failed_with_exception = True
|
||||
reporter.record_result(False)
|
||||
@@ -352,4 +360,4 @@ if __name__ == '__main__':
|
||||
passed = validate_network(
|
||||
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
||||
)
|
||||
raise SystemExit(0 if passed else 1)
|
||||
raise SystemExit(0 if passed.passed else 1)
|
||||
|
||||
Reference in New Issue
Block a user