validate.py now also checks pass timings
This commit is contained in:
@@ -1,8 +1,40 @@
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from colorama import Fore, Style
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
# Ordered (pass-name suffix, display label) pairs for the PIM passes whose
# timings are extracted from the compiler's timing report.
PIM_PASS_LABELS = (
    ("ONNXToSpatialPass", "ONNX to Spatial"),
    ("SpatialToPimPass", "Spatial to PIM"),
    ("PimBufferizationPass", "Bufferize PIM"),
    ("HostConstantFoldingPass", "Fold Host Constants"),
    ("MaterializeHostConstantsPass", "Materialize Host Constants"),
    ("VerificationPass", "Verify PIM"),
    ("EmitPimJsonPass", "Emit PIM JSON"),
)

# Suffix -> label lookup derived from the pairs above (insertion order kept).
PIM_PASS_LABEL_BY_SUFFIX = {suffix: label for suffix, label in PIM_PASS_LABELS}

# Matches a line shaped like "<decimal> (<percent>%) <text>".
# Group 1 is the decimal number, group 2 the trailing text with surrounding
# whitespace stripped.
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
|
||||
|
||||
|
||||
def _parse_pim_pass_timings(output_text):
    """Collect per-pass durations for the known PIM passes from a timing report.

    Each line of ``output_text`` is matched against ``TIMING_LINE_RE``. For a
    matching line, the captured name is compared against the suffixes in
    ``PIM_PASS_LABEL_BY_SUFFIX``; the first suffix the name ends with selects
    the label, and the captured number is added to that label's running sum.
    Lines that do not match, or whose name ends with no known suffix, are
    ignored.

    Args:
        output_text: Text to scan, processed line by line.

    Returns:
        A dict mapping display label to the summed duration for that label.

    Raises:
        RuntimeError: If no line contributed a timing for any known pass.
    """
    totals = {}
    for raw_line in output_text.splitlines():
        parsed = TIMING_LINE_RE.match(raw_line)
        if parsed is None:
            continue

        seconds_text, name = parsed.groups()
        # First matching suffix wins, in the table's iteration order.
        label = next(
            (lbl for sfx, lbl in PIM_PASS_LABEL_BY_SUFFIX.items() if name.endswith(sfx)),
            None,
        )
        if label is None:
            continue
        # A label may appear on several lines; accumulate rather than overwrite.
        totals[label] = totals.get(label, 0.0) + float(seconds_text)

    if totals:
        return totals
    raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
|
||||
|
||||
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
||||
@@ -16,16 +48,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
# "--use-experimental-conv-impl=true",
|
||||
f"--crossbar-size={crossbar_size}",
|
||||
f"--crossbar-count={crossbar_count}",
|
||||
"--enable-timing",
|
||||
]
|
||||
|
||||
try:
|
||||
run_command_with_reporter(
|
||||
output_text = run_command_with_reporter(
|
||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
||||
cwd=cwd,
|
||||
reporter=reporter,
|
||||
capture_output=True,
|
||||
)
|
||||
if reporter is None:
|
||||
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||
return _parse_pim_pass_timings(output_text)
|
||||
except subprocess.CalledProcessError:
|
||||
if reporter is None:
|
||||
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
||||
|
||||
@@ -19,6 +19,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
selector = selectors.DefaultSelector()
|
||||
recent_output = bytearray()
|
||||
captured_output = bytearray()
|
||||
|
||||
try:
|
||||
selector.register(fd, selectors.EVENT_READ)
|
||||
@@ -34,6 +35,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
reporter._clear()
|
||||
os.write(1, data)
|
||||
reporter._render()
|
||||
captured_output.extend(data)
|
||||
recent_output.extend(data)
|
||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||
@@ -43,12 +45,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||
return bytes(captured_output)
|
||||
|
||||
|
||||
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
||||
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
|
||||
if reporter is None:
|
||||
if capture_output:
|
||||
completed = subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
return completed.stdout.decode("utf-8", errors="replace")
|
||||
subprocess.run(cmd, cwd=cwd, check=True)
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
@@ -60,8 +72,8 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
_stream_output(process.stdout.fileno(), process, reporter)
|
||||
return
|
||||
output = _stream_output(process.stdout.fileno(), process, reporter)
|
||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
@@ -73,4 +85,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
||||
finally:
|
||||
os.close(slave_fd)
|
||||
|
||||
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
||||
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||
|
||||
@@ -8,6 +8,7 @@ import sys
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||
from raptor import PIM_PASS_LABELS
|
||||
|
||||
|
||||
def format_command(cmd):
|
||||
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
|
||||
reporter.resume()
|
||||
|
||||
|
||||
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
    """Print the average duration of each PIM pass and the average total.

    Nothing is printed when ``timed_benchmark_count`` is zero. Otherwise a
    heading is printed, followed by one line per label in ``PIM_PASS_LABELS``
    order (labels whose count is zero are skipped), and a final "Total" line.

    Args:
        pass_timing_sums: Mapping of label to accumulated duration.
        pass_timing_counts: Mapping of label to the number of samples
            accumulated for that label.
        total_timing_sum: Accumulated per-benchmark totals.
        timed_benchmark_count: Number of benchmarks that contributed timings.
    """
    if timed_benchmark_count == 0:
        return

    heading = "\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL
    print(heading)

    for _suffix, label in PIM_PASS_LABELS:
        samples = pass_timing_counts[label]
        if samples != 0:
            # Average over the benchmarks that actually reported this pass.
            average = pass_timing_sums[label] / samples
            print(f" {label.ljust(28)} {average:.4f}s")

    overall_average = total_timing_sum / timed_benchmark_count
    print(f" {'Total'.ljust(28)} {overall_average:.4f}s")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||
@@ -90,11 +104,15 @@ def main():
|
||||
print("=" * 72)
|
||||
|
||||
results = {} # relative_path -> passed
|
||||
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
|
||||
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
|
||||
total_timing_sum = 0.0
|
||||
timed_benchmark_count = 0
|
||||
reporter = ProgressReporter(len(onnx_files))
|
||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||
rel = onnx_path.relative_to(operations_dir)
|
||||
try:
|
||||
passed = validate_network(
|
||||
result = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
threshold=a.threshold,
|
||||
@@ -102,7 +120,15 @@ def main():
|
||||
model_index=index,
|
||||
model_total=len(onnx_files),
|
||||
)
|
||||
results[str(rel)] = passed
|
||||
results[str(rel)] = result.passed
|
||||
if result.pim_pass_timings:
|
||||
benchmark_total = 0.0
|
||||
for label, duration in result.pim_pass_timings.items():
|
||||
pass_timing_sums[label] += duration
|
||||
pass_timing_counts[label] += 1
|
||||
benchmark_total += duration
|
||||
total_timing_sum += benchmark_total
|
||||
timed_benchmark_count += 1
|
||||
except subprocess.CalledProcessError as exc:
|
||||
results[str(rel)] = False
|
||||
print_validation_error(reporter, rel, exc)
|
||||
@@ -131,6 +157,12 @@ def main():
|
||||
print(separator)
|
||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
total_timing_sum,
|
||||
timed_benchmark_count,
|
||||
)
|
||||
|
||||
sys.exit(0 if n_passed == n_total else 1)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import numpy as np
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from colorama import Style, Fore
|
||||
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
||||
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
|
||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Result of validating a single network.

    Bundles the pass/fail verdict with the PIM pass timings gathered while
    compiling, so callers can report both.
    """

    # Whether validation succeeded.
    passed: bool

    # Display label -> duration for each PIM pass that reported a timing.
    # Defaults to a fresh empty dict per instance.
    pim_pass_timings: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ProgressReporter:
|
||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
||||
self.total_models = total_models
|
||||
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||
failed_with_exception = False
|
||||
pim_pass_timings = {}
|
||||
|
||||
try:
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.advance()
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
compile_with_raptor(
|
||||
pim_pass_timings = compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.record_result(passed)
|
||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||
return passed
|
||||
return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
|
||||
except Exception:
|
||||
failed_with_exception = True
|
||||
reporter.record_result(False)
|
||||
@@ -352,4 +360,4 @@ if __name__ == '__main__':
|
||||
passed = validate_network(
|
||||
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
||||
)
|
||||
raise SystemExit(0 if passed else 1)
|
||||
raise SystemExit(0 if passed.passed else 1)
|
||||
|
||||
Reference in New Issue
Block a user