fix relative paths in validation scripts

This commit is contained in:
NiccoloN
2026-02-26 16:34:31 +01:00
parent a42ff74a3b
commit 0acd298e80
3 changed files with 15 additions and 18 deletions

View File

@@ -7,10 +7,6 @@ from colorama import Style, Fore
from validate_one import validate_network
def discover_onnx_files(operations_dir):
return sorted(operations_dir.rglob("*.onnx"))
def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", required=True, help="Path to the Raptor compiler binary.")
@@ -24,8 +20,8 @@ def main():
a = ap.parse_args()
script_dir = Path(__file__).parent.resolve()
operations_dir = Path(a.operations_dir) if a.operations_dir else script_dir / "operations"
simulator_dir = Path(a.simulator_dir) if a.simulator_dir else (
operations_dir = Path(a.operations_dir).resolve() if a.operations_dir else script_dir / "operations"
simulator_dir = Path(a.simulator_dir).resolve() if a.simulator_dir else (
script_dir / ".." / "backend-simulators" / "pim" / "pim-simulator"
)
@@ -33,7 +29,7 @@ def main():
print(Fore.RED + f"Operations directory not found: {operations_dir}" + Style.RESET_ALL)
sys.exit(1)
onnx_files = discover_onnx_files(operations_dir)
onnx_files = sorted(operations_dir.rglob("*.onnx"))
if not onnx_files:
print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL)
sys.exit(1)
@@ -46,15 +42,11 @@ def main():
header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}"
print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL)
try:
passed = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
)
except Exception as e:
print(Fore.RED + f" ERROR: {e}" + Style.RESET_ALL)
passed = False
passed = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
)
results[str(rel)] = passed