add validation artifacts cleanup
This commit is contained in:
@@ -13,6 +13,7 @@ from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
STAGE_COUNT = 6
|
||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||
|
||||
|
||||
class ProgressReporter:
|
||||
@@ -88,6 +89,27 @@ def run_command(cmd, cwd=None, reporter=None):
|
||||
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter)
|
||||
|
||||
|
||||
def clean_workspace_artifacts(workspace_dir, model_stem):
|
||||
workspace_dir = Path(workspace_dir)
|
||||
removed_paths = []
|
||||
|
||||
def remove_path(path):
|
||||
if path.is_symlink() or path.is_file():
|
||||
path.unlink(missing_ok=True)
|
||||
removed_paths.append(path)
|
||||
elif path.is_dir():
|
||||
shutil.rmtree(path)
|
||||
removed_paths.append(path)
|
||||
|
||||
for name in GENERATED_DIR_NAMES:
|
||||
remove_path(workspace_dir / name)
|
||||
|
||||
for suffix in (".onnx.mlir", ".so", ".tmp"):
|
||||
remove_path(workspace_dir / f"{model_stem}{suffix}")
|
||||
|
||||
return removed_paths
|
||||
|
||||
|
||||
def print_stage(reporter, model_index, model_total, model_name, title):
|
||||
stage_colors = {
|
||||
"Compile ONNX": Fore.BLUE,
|
||||
@@ -108,19 +130,15 @@ def print_info(reporter, message):
|
||||
|
||||
|
||||
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
|
||||
run_command([raptor_path, network_onnx_path, "--EmitONNXIR"], reporter=reporter)
|
||||
run_command([raptor_path, network_onnx_path], reporter=reporter)
|
||||
parent = network_onnx_path.parent
|
||||
stem = network_onnx_path.stem
|
||||
so_path = parent / f"{stem}.so"
|
||||
mlir_path = parent / f"{stem}.onnx.mlir"
|
||||
tmp_path = parent / f"{stem}.tmp"
|
||||
moved_so = runner_dir / so_path.name
|
||||
moved_mlir = raptor_dir / mlir_path.name
|
||||
so_path.rename(moved_so)
|
||||
mlir_path.rename(moved_mlir)
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
return moved_so, moved_mlir
|
||||
onnx_ir_base = raptor_dir / stem
|
||||
runner_base = runner_dir / stem
|
||||
run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"], reporter=reporter)
|
||||
run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter)
|
||||
network_so_path = runner_base.with_suffix(".so")
|
||||
network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
|
||||
onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
|
||||
return network_so_path, network_mlir_path
|
||||
|
||||
|
||||
def build_onnx_runner(source_dir, build_dir, reporter=None):
|
||||
@@ -200,6 +218,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter = reporter or ProgressReporter(model_total)
|
||||
|
||||
workspace_dir = network_onnx_path.parent
|
||||
clean_workspace_artifacts(workspace_dir, network_onnx_path.stem)
|
||||
raptor_dir = workspace_dir / "raptor"
|
||||
runner_dir = workspace_dir / "runner"
|
||||
runner_build_dir = runner_dir / "build"
|
||||
@@ -241,7 +260,9 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
compile_with_raptor(
|
||||
network_mlir_path, raptor_path, crossbar_size, crossbar_count, reporter=reporter)
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||
reporter.advance()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user