validate.py now also checks pass timings

This commit is contained in:
NiccoloN
2026-04-16 16:58:44 +02:00
parent ae93d1c563
commit 831b7be4e7
4 changed files with 99 additions and 11 deletions

View File

@@ -1,8 +1,40 @@
import re
import subprocess
from pathlib import Path
from colorama import Fore, Style
from subprocess_utils import run_command_with_reporter
# (pass-name suffix, human-readable label) pairs for the PIM passes whose
# timings are extracted from the compiler's timing report. The tuple order is
# the order in which consumers iterate over the labels.
PIM_PASS_LABELS = (
    ("ONNXToSpatialPass", "ONNX to Spatial"),
    ("SpatialToPimPass", "Spatial to PIM"),
    ("PimBufferizationPass", "Bufferize PIM"),
    ("HostConstantFoldingPass", "Fold Host Constants"),
    ("MaterializeHostConstantsPass", "Materialize Host Constants"),
    ("VerificationPass", "Verify PIM"),
    ("EmitPimJsonPass", "Emit PIM JSON"),
)

# Same pairs as a suffix -> label mapping (insertion order follows the tuple).
PIM_PASS_LABEL_BY_SUFFIX = {suffix: label for suffix, label in PIM_PASS_LABELS}

# Matches one timing line of the form "<seconds> (<percent>%) <name>":
#   group(1) -> the decimal duration, group(2) -> the rest of the line with
#   surrounding whitespace stripped.
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
def _parse_pim_pass_timings(output_text):
    """Extract per-pass durations for the known PIM passes from a timing report.

    Each line of ``output_text`` is matched against ``TIMING_LINE_RE``. When
    the reported name ends with one of the suffixes in
    ``PIM_PASS_LABEL_BY_SUFFIX``, the line's duration is added to the total
    for the first matching label. Lines that do not match the pattern, or
    whose name matches no known suffix, are ignored.

    Args:
        output_text: Captured compiler output as a single string.

    Returns:
        A dict mapping pass label to the summed duration (float) of every
        timing line attributed to that label.

    Raises:
        RuntimeError: If no line could be attributed to any known PIM pass.
    """
    totals_by_label = {}

    for raw_line in output_text.splitlines():
        parsed = TIMING_LINE_RE.match(raw_line)
        if parsed is None:
            continue

        seconds_text, reported_name = parsed.groups()
        # First suffix (in mapping order) that the reported name ends with.
        matching_label = next(
            (
                label
                for suffix, label in PIM_PASS_LABEL_BY_SUFFIX.items()
                if reported_name.endswith(suffix)
            ),
            None,
        )
        if matching_label is not None:
            previous_total = totals_by_label.get(matching_label, 0.0)
            totals_by_label[matching_label] = previous_total + float(seconds_text)

    if totals_by_label:
        return totals_by_label
    raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
crossbar_size, crossbar_count, cwd=None, reporter=None):
@@ -16,16 +48,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
# "--use-experimental-conv-impl=true",
f"--crossbar-size={crossbar_size}",
f"--crossbar-count={crossbar_count}",
"--enable-timing",
]
try:
run_command_with_reporter(
output_text = run_command_with_reporter(
[str(raptor_onnx_path)] + [str(arg) for arg in args],
cwd=cwd,
reporter=reporter,
capture_output=True,
)
if reporter is None:
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
return _parse_pim_pass_timings(output_text)
except subprocess.CalledProcessError:
if reporter is None:
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)

View File

@@ -19,6 +19,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
selector = selectors.DefaultSelector()
recent_output = bytearray()
captured_output = bytearray()
try:
selector.register(fd, selectors.EVENT_READ)
@@ -34,6 +35,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
reporter._clear()
os.write(1, data)
reporter._render()
captured_output.extend(data)
recent_output.extend(data)
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
@@ -43,12 +45,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
return_code = process.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
return bytes(captured_output)
def run_command_with_reporter(cmd, cwd=None, reporter=None):
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
if reporter is None:
if capture_output:
completed = subprocess.run(
cmd,
cwd=cwd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
return completed.stdout.decode("utf-8", errors="replace")
subprocess.run(cmd, cwd=cwd, check=True)
return
return None
try:
master_fd, slave_fd = pty.openpty()
@@ -60,8 +72,8 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
_stream_output(process.stdout.fileno(), process, reporter)
return
output = _stream_output(process.stdout.fileno(), process, reporter)
return output.decode("utf-8", errors="replace") if capture_output else None
try:
process = subprocess.Popen(
@@ -73,4 +85,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
finally:
os.close(slave_fd)
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
return output.decode("utf-8", errors="replace") if capture_output else None

View File

@@ -8,6 +8,7 @@ import sys
from pathlib import Path
from colorama import Style, Fore
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
from raptor import PIM_PASS_LABELS
def format_command(cmd):
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
reporter.resume()
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
    """Print the average duration of each PIM pass and the average total.

    Args:
        pass_timing_sums: Mapping of pass label to its accumulated duration.
        pass_timing_counts: Mapping of pass label to the number of
            measurements accumulated for it.
        total_timing_sum: Accumulated total duration over all timed benchmarks.
        timed_benchmark_count: Number of benchmarks that contributed timings.
            When it is 0 nothing is printed.
    """
    if timed_benchmark_count == 0:
        return

    heading = "\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL
    print(heading)

    # Walk the labels in their declared order; passes without any
    # measurement are left out so no division by zero can occur.
    for _suffix, label in PIM_PASS_LABELS:
        count = pass_timing_counts[label]
        if count != 0:
            print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")

    print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")
def main():
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
@@ -90,11 +104,15 @@ def main():
print("=" * 72)
results = {} # relative_path -> passed
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
total_timing_sum = 0.0
timed_benchmark_count = 0
reporter = ProgressReporter(len(onnx_files))
for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir)
try:
passed = validate_network(
result = validate_network(
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
threshold=a.threshold,
@@ -102,7 +120,15 @@ def main():
model_index=index,
model_total=len(onnx_files),
)
results[str(rel)] = passed
results[str(rel)] = result.passed
if result.pim_pass_timings:
benchmark_total = 0.0
for label, duration in result.pim_pass_timings.items():
pass_timing_sums[label] += duration
pass_timing_counts[label] += 1
benchmark_total += duration
total_timing_sum += benchmark_total
timed_benchmark_count += 1
except subprocess.CalledProcessError as exc:
results[str(rel)] = False
print_validation_error(reporter, rel, exc)
@@ -131,6 +157,12 @@ def main():
print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
print_average_pim_pass_timings(
pass_timing_sums,
pass_timing_counts,
total_timing_sum,
timed_benchmark_count,
)
sys.exit(0 if n_passed == n_total else 1)

View File

@@ -4,6 +4,7 @@ import numpy as np
import subprocess
import shutil
import sys
from dataclasses import dataclass, field
from pathlib import Path
from colorama import Style, Fore
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
@dataclass
class ValidationResult:
    """Outcome of validating a single network.

    Supports positional construction as ``ValidationResult(passed)``; the
    timings default to a fresh empty dict for every instance.
    """

    # Whether the validation passed.
    passed: bool
    # Pass label -> duration; empty when no timings were collected.
    pim_pass_timings: dict[str, float] = field(default_factory=dict)
class ProgressReporter:
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
self.total_models = total_models
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
failed_with_exception = False
pim_pass_timings = {}
try:
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
compile_with_raptor(
pim_pass_timings = compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count,
cwd=raptor_dir, reporter=reporter)
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.record_result(passed)
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
except Exception:
failed_with_exception = True
reporter.record_result(False)
@@ -352,4 +360,4 @@ if __name__ == '__main__':
passed = validate_network(
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
)
raise SystemExit(0 if passed else 1)
raise SystemExit(0 if passed.passed else 1)