fix gemm segfault
Some checks failed
Validate Operations / config (push) Successful in 1m27s
Validate Operations / build-mlir-cache (push) Successful in 2h0m37s
Validate Operations / validate (push) Failing after 3m41s

print exit signals on validation failure
This commit is contained in:
NiccoloN
2026-03-20 14:00:16 +01:00
parent bb6dcd38a3
commit dbe646ac0d
4 changed files with 52 additions and 5 deletions

View File

@@ -4,6 +4,8 @@ import pty
import selectors
import subprocess
MAX_ERROR_OUTPUT_BYTES = 8192
def _read_chunk(fd, treat_eio_as_eof=False):
try:
@@ -16,6 +18,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
selector = selectors.DefaultSelector()
recent_output = bytearray()
try:
selector.register(fd, selectors.EVENT_READ)
@@ -31,12 +34,15 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
reporter._clear()
os.write(1, data)
reporter._render()
recent_output.extend(data)
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
finally:
selector.close()
return_code = process.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, process.args)
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
def run_command_with_reporter(cmd, cwd=None, reporter=None):

View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python3
import argparse
import shlex
import signal
import subprocess
import sys
from pathlib import Path
@@ -8,6 +10,37 @@ from colorama import Style, Fore
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
def format_command(cmd):
    """Render a command as a single printable string.

    A list or tuple is treated as an argv sequence: every element is
    converted with ``str`` and the result is shell-quoted and joined with
    ``shlex.join``. Any other value is simply passed through ``str``.
    """
    if not isinstance(cmd, (list, tuple)):
        return str(cmd)
    # Elements may be non-strings (e.g. Path objects), so stringify first.
    return shlex.join(map(str, cmd))
def format_return_status(returncode):
    """Describe a process return code as a human-readable sentence.

    A negative value is reported as termination by the signal whose number
    is the absolute value of the code; the signal's symbolic name is looked
    up via ``signal.Signals`` and falls back to ``"UNKNOWN"`` when the
    number is not a known signal. Any other value is reported as a plain
    exit code.
    """
    if returncode >= 0:
        return f"Program exited with code {returncode}."

    signo = -returncode
    try:
        label = signal.Signals(signo).name
    except ValueError:
        # Number does not map to a signal known on this platform.
        label = "UNKNOWN"
    return f"Program terminated by signal {label} ({signo})."
def print_validation_error(reporter, rel, exc):
    """Print a description of a validation failure to stderr.

    The reporter is suspended while the message is written and resumed
    afterwards, even if writing fails part-way through.

    Args:
        reporter: Object exposing ``suspend()`` and ``resume()``.
        rel: Identifier of the item being validated; interpolated into the
            header line.
        exc: The exception that was raised. For a
            ``subprocess.CalledProcessError`` the return status and a
            copy-pasteable retry command are printed; for any other
            exception its type name and message are printed.
    """
    reporter.suspend()
    try:
        print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL,
              file=sys.stderr, flush=True)
        if isinstance(exc, subprocess.CalledProcessError):
            print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
            print("Retry command:", file=sys.stderr, flush=True)
            print(format_command(exc.cmd), file=sys.stderr, flush=True)
        else:
            print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
        print("=" * 72, file=sys.stderr, flush=True)
    finally:
        # Always undo the suspend above so a failed write to stderr cannot
        # leave the reporter stuck in the suspended state.
        reporter.resume()
def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
@@ -70,8 +103,12 @@ def main():
model_total=len(onnx_files),
)
results[str(rel)] = passed
except (subprocess.CalledProcessError, Exception):
except subprocess.CalledProcessError as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
except Exception as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
reporter.finish()

View File

@@ -73,10 +73,10 @@ class ProgressReporter:
def suspend(self):
self.suspended = True
self._clear()
sys.stdout.flush()
def resume(self):
self.suspended = False
self._render()
def finish(self):
if self.enabled:
@@ -227,6 +227,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
failed_with_exception = False
try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
@@ -287,10 +288,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
except Exception:
failed_with_exception = True
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
reporter.suspend()
raise
finally:
reporter.log("=" * 72)
if not failed_with_exception:
reporter.log("=" * 72)
if owns_reporter:
reporter.finish()