fix gemm segfault
print exit signals on validation failure
This commit is contained in:
@@ -164,7 +164,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||||
|
|
||||||
if (!isVectorShape(aType.getShape()) || !isVectorShape(cType.getShape()))
|
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
|
||||||
// Not a gemv
|
// Not a gemv
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import pty
|
|||||||
import selectors
|
import selectors
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
MAX_ERROR_OUTPUT_BYTES = 8192
|
||||||
|
|
||||||
|
|
||||||
def _read_chunk(fd, treat_eio_as_eof=False):
|
def _read_chunk(fd, treat_eio_as_eof=False):
|
||||||
try:
|
try:
|
||||||
@@ -16,6 +18,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
|||||||
|
|
||||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||||
selector = selectors.DefaultSelector()
|
selector = selectors.DefaultSelector()
|
||||||
|
recent_output = bytearray()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
selector.register(fd, selectors.EVENT_READ)
|
selector.register(fd, selectors.EVENT_READ)
|
||||||
@@ -31,12 +34,15 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
|||||||
reporter._clear()
|
reporter._clear()
|
||||||
os.write(1, data)
|
os.write(1, data)
|
||||||
reporter._render()
|
reporter._render()
|
||||||
|
recent_output.extend(data)
|
||||||
|
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||||
|
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||||
finally:
|
finally:
|
||||||
selector.close()
|
selector.close()
|
||||||
|
|
||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
raise subprocess.CalledProcessError(return_code, process.args)
|
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||||
|
|
||||||
|
|
||||||
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import shlex
|
||||||
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -8,6 +10,37 @@ from colorama import Style, Fore
|
|||||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||||
|
|
||||||
|
|
||||||
|
def format_command(cmd):
|
||||||
|
if isinstance(cmd, (list, tuple)):
|
||||||
|
return shlex.join(str(arg) for arg in cmd)
|
||||||
|
return str(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def format_return_status(returncode):
|
||||||
|
if returncode < 0:
|
||||||
|
signal_num = -returncode
|
||||||
|
try:
|
||||||
|
signal_name = signal.Signals(signal_num).name
|
||||||
|
except ValueError:
|
||||||
|
signal_name = "UNKNOWN"
|
||||||
|
return f"Program terminated by signal {signal_name} ({signal_num})."
|
||||||
|
return f"Program exited with code {returncode}."
|
||||||
|
|
||||||
|
|
||||||
|
def print_validation_error(reporter, rel, exc):
|
||||||
|
reporter.suspend()
|
||||||
|
print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL,
|
||||||
|
file=sys.stderr, flush=True)
|
||||||
|
if isinstance(exc, subprocess.CalledProcessError):
|
||||||
|
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||||
|
print("Retry command:", file=sys.stderr, flush=True)
|
||||||
|
print(format_command(exc.cmd), file=sys.stderr, flush=True)
|
||||||
|
else:
|
||||||
|
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
||||||
|
print("=" * 72, file=sys.stderr, flush=True)
|
||||||
|
reporter.resume()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||||
@@ -70,8 +103,12 @@ def main():
|
|||||||
model_total=len(onnx_files),
|
model_total=len(onnx_files),
|
||||||
)
|
)
|
||||||
results[str(rel)] = passed
|
results[str(rel)] = passed
|
||||||
except (subprocess.CalledProcessError, Exception):
|
except subprocess.CalledProcessError as exc:
|
||||||
results[str(rel)] = False
|
results[str(rel)] = False
|
||||||
|
print_validation_error(reporter, rel, exc)
|
||||||
|
except Exception as exc:
|
||||||
|
results[str(rel)] = False
|
||||||
|
print_validation_error(reporter, rel, exc)
|
||||||
|
|
||||||
reporter.finish()
|
reporter.finish()
|
||||||
|
|
||||||
|
|||||||
@@ -73,10 +73,10 @@ class ProgressReporter:
|
|||||||
def suspend(self):
|
def suspend(self):
|
||||||
self.suspended = True
|
self.suspended = True
|
||||||
self._clear()
|
self._clear()
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
def resume(self):
|
def resume(self):
|
||||||
self.suspended = False
|
self.suspended = False
|
||||||
self._render()
|
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
@@ -227,6 +227,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
|
|
||||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||||
|
failed_with_exception = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||||
@@ -287,10 +288,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||||
return passed
|
return passed
|
||||||
except Exception:
|
except Exception:
|
||||||
|
failed_with_exception = True
|
||||||
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
||||||
|
reporter.suspend()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
reporter.log("=" * 72)
|
if not failed_with_exception:
|
||||||
|
reporter.log("=" * 72)
|
||||||
if owns_reporter:
|
if owns_reporter:
|
||||||
reporter.finish()
|
reporter.finish()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user