fix gemm segfault
Some checks failed
Validate Operations / config (push) Successful in 1m27s
Validate Operations / build-mlir-cache (push) Successful in 2h0m37s
Validate Operations / validate (push) Failing after 3m41s

print exit signals on validation failure
This commit is contained in:
NiccoloN
2026-03-20 14:00:16 +01:00
parent bb6dcd38a3
commit dbe646ac0d
4 changed files with 52 additions and 5 deletions

View File

@@ -164,7 +164,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!isVectorShape(aType.getShape()) || !isVectorShape(cType.getShape())) if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv // Not a gemv
return failure(); return failure();

View File

@@ -4,6 +4,8 @@ import pty
import selectors import selectors
import subprocess import subprocess
MAX_ERROR_OUTPUT_BYTES = 8192
def _read_chunk(fd, treat_eio_as_eof=False): def _read_chunk(fd, treat_eio_as_eof=False):
try: try:
@@ -16,6 +18,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
def _stream_output(fd, process, reporter, treat_eio_as_eof=False): def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
selector = selectors.DefaultSelector() selector = selectors.DefaultSelector()
recent_output = bytearray()
try: try:
selector.register(fd, selectors.EVENT_READ) selector.register(fd, selectors.EVENT_READ)
@@ -31,12 +34,15 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
reporter._clear() reporter._clear()
os.write(1, data) os.write(1, data)
reporter._render() reporter._render()
recent_output.extend(data)
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
finally: finally:
selector.close() selector.close()
return_code = process.wait() return_code = process.wait()
if return_code != 0: if return_code != 0:
raise subprocess.CalledProcessError(return_code, process.args) raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
def run_command_with_reporter(cmd, cwd=None, reporter=None): def run_command_with_reporter(cmd, cwd=None, reporter=None):

View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import shlex
import signal
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
@@ -8,6 +10,37 @@ from colorama import Style, Fore
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
def format_command(cmd):
if isinstance(cmd, (list, tuple)):
return shlex.join(str(arg) for arg in cmd)
return str(cmd)
def format_return_status(returncode):
if returncode < 0:
signal_num = -returncode
try:
signal_name = signal.Signals(signal_num).name
except ValueError:
signal_name = "UNKNOWN"
return f"Program terminated by signal {signal_name} ({signal_num})."
return f"Program exited with code {returncode}."
def print_validation_error(reporter, rel, exc):
reporter.suspend()
print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL,
file=sys.stderr, flush=True)
if isinstance(exc, subprocess.CalledProcessError):
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
print("Retry command:", file=sys.stderr, flush=True)
print(format_command(exc.cmd), file=sys.stderr, flush=True)
else:
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
print("=" * 72, file=sys.stderr, flush=True)
reporter.resume()
def main(): def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.") ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.") ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
@@ -70,8 +103,12 @@ def main():
model_total=len(onnx_files), model_total=len(onnx_files),
) )
results[str(rel)] = passed results[str(rel)] = passed
except (subprocess.CalledProcessError, Exception): except subprocess.CalledProcessError as exc:
results[str(rel)] = False results[str(rel)] = False
print_validation_error(reporter, rel, exc)
except Exception as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
reporter.finish() reporter.finish()

View File

@@ -73,10 +73,10 @@ class ProgressReporter:
def suspend(self): def suspend(self):
self.suspended = True self.suspended = True
self._clear() self._clear()
sys.stdout.flush()
def resume(self): def resume(self):
self.suspended = False self.suspended = False
self._render()
def finish(self): def finish(self):
if self.enabled: if self.enabled:
@@ -227,6 +227,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL + reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}") f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
failed_with_exception = False
try: try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX") print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
@@ -287,9 +288,12 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL) reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed return passed
except Exception: except Exception:
failed_with_exception = True
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL) reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
reporter.suspend()
raise raise
finally: finally:
if not failed_with_exception:
reporter.log("=" * 72) reporter.log("=" * 72)
if owns_reporter: if owns_reporter:
reporter.finish() reporter.finish()