fix gemm segfault
print exit signals on validation failure
This commit is contained in:
@@ -164,7 +164,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || !isVectorShape(cType.getShape()))
|
||||
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||
// Not a gemv
|
||||
return failure();
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import pty
|
||||
import selectors
|
||||
import subprocess
|
||||
|
||||
MAX_ERROR_OUTPUT_BYTES = 8192
|
||||
|
||||
|
||||
def _read_chunk(fd, treat_eio_as_eof=False):
|
||||
try:
|
||||
@@ -16,6 +18,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
||||
|
||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
selector = selectors.DefaultSelector()
|
||||
recent_output = bytearray()
|
||||
|
||||
try:
|
||||
selector.register(fd, selectors.EVENT_READ)
|
||||
@@ -31,12 +34,15 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
reporter._clear()
|
||||
os.write(1, data)
|
||||
reporter._render()
|
||||
recent_output.extend(data)
|
||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||
finally:
|
||||
selector.close()
|
||||
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
raise subprocess.CalledProcessError(return_code, process.args)
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||
|
||||
|
||||
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -8,6 +10,37 @@ from colorama import Style, Fore
|
||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||
|
||||
|
||||
def format_command(cmd):
    """Return *cmd* as a single printable command-line string.

    A list or tuple is treated as an argument vector: every element is
    converted with ``str`` and the result is shell-quoted via
    ``shlex.join`` so it can be pasted back into a shell. Any other value
    is returned as ``str(cmd)`` unchanged.
    """
    if not isinstance(cmd, (list, tuple)):
        return str(cmd)
    # Elements may be non-strings (e.g. Path objects), so stringify first.
    return shlex.join(map(str, cmd))
|
||||
|
||||
|
||||
def format_return_status(returncode):
    """Describe a process return code in one human-readable sentence.

    A non-negative value is reported as a plain exit code. A negative
    value is reported as termination by signal number ``-returncode``;
    the signal name is looked up through ``signal.Signals`` and falls back
    to ``"UNKNOWN"`` when the number is not a known signal.
    """
    if returncode >= 0:
        return f"Program exited with code {returncode}."

    number = -returncode
    try:
        name = signal.Signals(number).name
    except ValueError:
        # The number does not correspond to any signal known to this platform.
        name = "UNKNOWN"
    return f"Program terminated by signal {name} ({number})."
|
||||
|
||||
|
||||
def print_validation_error(reporter, rel, exc):
    """Print a failure report for one validated model to stderr.

    The report starts with a bright red "Exception while validating ..."
    header. For a ``subprocess.CalledProcessError`` it then shows the exit
    status (exit code or terminating signal) and the command to retry;
    for any other exception it shows the exception type and message. A
    72-character ``=`` rule closes the report.

    Args:
        reporter: Object exposing ``suspend()`` and ``resume()``; it is
            suspended while the report is written (presumably so the
            progress display does not interleave with the message) and is
            always resumed afterwards.
        rel: Identifier of the model being validated, interpolated into
            the header line.
        exc: The exception that caused the validation to fail.
    """
    # Build the whole report first so that a failure while formatting
    # (e.g. an exception whose __str__ raises) never leaves the reporter
    # in a suspended state.
    report = [Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL]
    if isinstance(exc, subprocess.CalledProcessError):
        report.append(format_return_status(exc.returncode))
        report.append("Retry command:")
        report.append(format_command(exc.cmd))
    else:
        report.append(f"{type(exc).__name__}: {exc}")
    report.append("=" * 72)

    reporter.suspend()
    try:
        # One write and one flush instead of a flush per line.
        print("\n".join(report), file=sys.stderr, flush=True)
    finally:
        # Resume even if writing to stderr fails, so progress output continues.
        reporter.resume()
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||
@@ -70,8 +103,12 @@ def main():
|
||||
model_total=len(onnx_files),
|
||||
)
|
||||
results[str(rel)] = passed
|
||||
except (subprocess.CalledProcessError, Exception):
|
||||
except subprocess.CalledProcessError as exc:
|
||||
results[str(rel)] = False
|
||||
print_validation_error(reporter, rel, exc)
|
||||
except Exception as exc:
|
||||
results[str(rel)] = False
|
||||
print_validation_error(reporter, rel, exc)
|
||||
|
||||
reporter.finish()
|
||||
|
||||
|
||||
@@ -73,10 +73,10 @@ class ProgressReporter:
|
||||
def suspend(self):
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
|
||||
def resume(self):
|
||||
self.suspended = False
|
||||
self._render()
|
||||
|
||||
def finish(self):
|
||||
if self.enabled:
|
||||
@@ -227,6 +227,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
|
||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||
failed_with_exception = False
|
||||
|
||||
try:
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||
@@ -287,10 +288,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||
return passed
|
||||
except Exception:
|
||||
failed_with_exception = True
|
||||
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
||||
reporter.suspend()
|
||||
raise
|
||||
finally:
|
||||
reporter.log("=" * 72)
|
||||
if not failed_with_exception:
|
||||
reporter.log("=" * 72)
|
||||
if owns_reporter:
|
||||
reporter.finish()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user