From dbe646ac0d9af283cde58533a2cba94fa8e6f951 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 20 Mar 2026 14:00:16 +0100 Subject: [PATCH] fix gemm segfault print exit signals on validation failure --- .../Conversion/ONNXToSpatial/Math/Gemm.cpp | 2 +- validation/subprocess_utils.py | 8 +++- validation/validate.py | 39 ++++++++++++++++++- validation/validate_one.py | 8 +++- 4 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp index b3c44c9..53bde28 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -164,7 +164,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); - if (!isVectorShape(aType.getShape()) || !isVectorShape(cType.getShape())) + if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) // Not a gemv return failure(); diff --git a/validation/subprocess_utils.py b/validation/subprocess_utils.py index fb0ad29..c1e68b6 100644 --- a/validation/subprocess_utils.py +++ b/validation/subprocess_utils.py @@ -4,6 +4,8 @@ import pty import selectors import subprocess +MAX_ERROR_OUTPUT_BYTES = 8192 + def _read_chunk(fd, treat_eio_as_eof=False): try: @@ -16,6 +18,7 @@ def _read_chunk(fd, treat_eio_as_eof=False): def _stream_output(fd, process, reporter, treat_eio_as_eof=False): selector = selectors.DefaultSelector() + recent_output = bytearray() try: selector.register(fd, selectors.EVENT_READ) @@ -31,12 +34,15 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False): reporter._clear() os.write(1, data) reporter._render() + recent_output.extend(data) + if len(recent_output) > MAX_ERROR_OUTPUT_BYTES: + del recent_output[:-MAX_ERROR_OUTPUT_BYTES] finally: selector.close() return_code = process.wait() if return_code != 0: - raise subprocess.CalledProcessError(return_code, process.args) + raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output)) def run_command_with_reporter(cmd, cwd=None, reporter=None): diff --git a/validation/validate.py b/validation/validate.py index c959360..3da7d47 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 import argparse +import shlex +import signal import subprocess import sys from pathlib import Path @@ -8,6 +10,37 @@ from colorama import Style, Fore from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network +def format_command(cmd): + if isinstance(cmd, (list, tuple)): + return shlex.join(str(arg) for arg in cmd) + return str(cmd) + + +def format_return_status(returncode): + if returncode < 0: + signal_num = -returncode + try: + signal_name = signal.Signals(signal_num).name + except ValueError: + signal_name = "UNKNOWN" + return f"Program terminated by signal {signal_name} ({signal_num})." + return f"Program exited with code {returncode}." + + +def print_validation_error(reporter, rel, exc): + reporter.suspend() + print(Style.BRIGHT + Fore.RED + f"Exception while validating {rel}" + Style.RESET_ALL, + file=sys.stderr, flush=True) + if isinstance(exc, subprocess.CalledProcessError): + print(format_return_status(exc.returncode), file=sys.stderr, flush=True) + print("Retry command:", file=sys.stderr, flush=True) + print(format_command(exc.cmd), file=sys.stderr, flush=True) + else: + print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True) + print("=" * 72, file=sys.stderr, flush=True) + reporter.resume() + + def main(): ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.") ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.") @@ -70,8 +103,12 @@ def main(): model_total=len(onnx_files), ) results[str(rel)] = passed - except (subprocess.CalledProcessError, Exception): + except subprocess.CalledProcessError as exc: results[str(rel)] = False + print_validation_error(reporter, rel, exc) + except Exception as exc: + results[str(rel)] = False + print_validation_error(reporter, rel, exc) reporter.finish() diff --git a/validation/validate_one.py b/validation/validate_one.py index 9fab85b..d5c2f60 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -73,10 +73,10 @@ class ProgressReporter: def suspend(self): self.suspended = True self._clear() + sys.stdout.flush() def resume(self): self.suspended = False - self._render() def finish(self): if self.enabled: @@ -227,6 +227,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL + f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}") + failed_with_exception = False try: print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX") @@ -287,10 +288,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL) return passed except Exception: + failed_with_exception = True reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL) + reporter.suspend() raise finally: - reporter.log("=" * 72) + if not failed_with_exception: + reporter.log("=" * 72) if owns_reporter: reporter.finish()