add validation artifacts cleanup

This commit is contained in:
NiccoloN
2026-03-20 13:15:08 +01:00
parent db3f52a647
commit 916a09414c
3 changed files with 59 additions and 17 deletions

View File

@@ -4,10 +4,13 @@ from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter from subprocess_utils import run_command_with_reporter
def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count, reporter=None): def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
crossbar_size, crossbar_count, cwd=None, reporter=None):
# Define the arguments, with the possibility to set crossbar size and count # Define the arguments, with the possibility to set crossbar size and count
args = [ args = [
network_path, network_path,
"-o",
output_base,
"--maccel=PIM", "--maccel=PIM",
"--EmitPimCodegen", "--EmitPimCodegen",
# "--use-experimental-conv-impl=true", # "--use-experimental-conv-impl=true",
@@ -18,6 +21,7 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, cro
try: try:
run_command_with_reporter( run_command_with_reporter(
[str(raptor_onnx_path)] + [str(arg) for arg in args], [str(raptor_onnx_path)] + [str(arg) for arg in args],
cwd=cwd,
reporter=reporter, reporter=reporter,
) )
if reporter is None: if reporter is None:

View File

@@ -5,19 +5,21 @@ import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from colorama import Style, Fore from colorama import Style, Fore
from validate_one import ProgressReporter, validate_network from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
def main(): def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.") ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", required=True, help="Path to the Raptor compiler binary.") ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
ap.add_argument("--onnx-include-dir", required=True, help="Path to OnnxMlirRuntime include directory.") ap.add_argument("--onnx-include-dir", help="Path to OnnxMlirRuntime include directory.")
ap.add_argument("--operations-dir", default=None, help="Root of the operations tree (default: operations).") ap.add_argument("--operations-dir", default=None, help="Root of the operations tree (default: operations).")
ap.add_argument("--simulator-dir", default=None, ap.add_argument("--simulator-dir", default=None,
help="Path to pim-simulator crate root (default: auto-detected relative to script).") help="Path to pim-simulator crate root (default: auto-detected relative to script).")
ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.") ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
ap.add_argument("--crossbar-size", type=int, default=64) ap.add_argument("--crossbar-size", type=int, default=64)
ap.add_argument("--crossbar-count", type=int, default=8) ap.add_argument("--crossbar-count", type=int, default=8)
ap.add_argument("--clean", action="store_true",
help="Remove generated validation artifacts under each model workspace and exit.")
a = ap.parse_args() a = ap.parse_args()
script_dir = Path(__file__).parent.resolve() script_dir = Path(__file__).parent.resolve()
@@ -35,6 +37,21 @@ def main():
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL) print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
sys.exit(1) sys.exit(1)
if a.clean:
removed_count = 0
for onnx_path in onnx_files:
removed_count += len(clean_workspace_artifacts(onnx_path.parent, onnx_path.stem))
print(Style.BRIGHT + f"Removed {removed_count} generated artifact path(s)." + Style.RESET_ALL)
sys.exit(0)
missing_args = []
if not a.raptor_path:
missing_args.append("--raptor-path")
if not a.onnx_include_dir:
missing_args.append("--onnx-include-dir")
if missing_args:
ap.error("the following arguments are required unless --clean is used: " + ", ".join(missing_args))
print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL) print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." + Style.RESET_ALL)
print(f"Operations root: {operations_dir}") print(f"Operations root: {operations_dir}")
print("=" * 72) print("=" * 72)

View File

@@ -13,6 +13,7 @@ from subprocess_utils import run_command_with_reporter
STAGE_COUNT = 6 STAGE_COUNT = 6
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
class ProgressReporter: class ProgressReporter:
@@ -88,6 +89,27 @@ def run_command(cmd, cwd=None, reporter=None):
run_command_with_reporter(cmd, cwd=cwd, reporter=reporter) run_command_with_reporter(cmd, cwd=cwd, reporter=reporter)
def clean_workspace_artifacts(workspace_dir, model_stem):
workspace_dir = Path(workspace_dir)
removed_paths = []
def remove_path(path):
if path.is_symlink() or path.is_file():
path.unlink(missing_ok=True)
removed_paths.append(path)
elif path.is_dir():
shutil.rmtree(path)
removed_paths.append(path)
for name in GENERATED_DIR_NAMES:
remove_path(workspace_dir / name)
for suffix in (".onnx.mlir", ".so", ".tmp"):
remove_path(workspace_dir / f"{model_stem}{suffix}")
return removed_paths
def print_stage(reporter, model_index, model_total, model_name, title): def print_stage(reporter, model_index, model_total, model_name, title):
stage_colors = { stage_colors = {
"Compile ONNX": Fore.BLUE, "Compile ONNX": Fore.BLUE,
@@ -108,19 +130,15 @@ def print_info(reporter, message):
def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None): def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None):
run_command([raptor_path, network_onnx_path, "--EmitONNXIR"], reporter=reporter)
run_command([raptor_path, network_onnx_path], reporter=reporter)
parent = network_onnx_path.parent
stem = network_onnx_path.stem stem = network_onnx_path.stem
so_path = parent / f"{stem}.so" onnx_ir_base = raptor_dir / stem
mlir_path = parent / f"{stem}.onnx.mlir" runner_base = runner_dir / stem
tmp_path = parent / f"{stem}.tmp" run_command([raptor_path, network_onnx_path, "-o", onnx_ir_base, "--EmitONNXIR"], reporter=reporter)
moved_so = runner_dir / so_path.name run_command([raptor_path, network_onnx_path, "-o", runner_base], reporter=reporter)
moved_mlir = raptor_dir / mlir_path.name network_so_path = runner_base.with_suffix(".so")
so_path.rename(moved_so) network_mlir_path = onnx_ir_base.with_suffix(".onnx.mlir")
mlir_path.rename(moved_mlir) onnx_ir_base.with_suffix(".tmp").unlink(missing_ok=True)
tmp_path.unlink(missing_ok=True) return network_so_path, network_mlir_path
return moved_so, moved_mlir
def build_onnx_runner(source_dir, build_dir, reporter=None): def build_onnx_runner(source_dir, build_dir, reporter=None):
@@ -200,6 +218,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter = reporter or ProgressReporter(model_total) reporter = reporter or ProgressReporter(model_total)
workspace_dir = network_onnx_path.parent workspace_dir = network_onnx_path.parent
clean_workspace_artifacts(workspace_dir, network_onnx_path.stem)
raptor_dir = workspace_dir / "raptor" raptor_dir = workspace_dir / "raptor"
runner_dir = workspace_dir / "runner" runner_dir = workspace_dir / "runner"
runner_build_dir = runner_dir / "build" runner_build_dir = runner_dir / "build"
@@ -241,7 +260,9 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
compile_with_raptor( compile_with_raptor(
network_mlir_path, raptor_path, crossbar_size, crossbar_count, reporter=reporter) network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count,
cwd=raptor_dir, reporter=reporter)
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}") print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
reporter.advance() reporter.advance()