validate.py now also checks pass timings
This commit is contained in:
@@ -1,8 +1,40 @@
|
|||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
from subprocess_utils import run_command_with_reporter
|
from subprocess_utils import run_command_with_reporter
|
||||||
|
|
||||||
|
PIM_PASS_LABELS = (
|
||||||
|
("ONNXToSpatialPass", "ONNX to Spatial"),
|
||||||
|
("SpatialToPimPass", "Spatial to PIM"),
|
||||||
|
("PimBufferizationPass", "Bufferize PIM"),
|
||||||
|
("HostConstantFoldingPass", "Fold Host Constants"),
|
||||||
|
("MaterializeHostConstantsPass", "Materialize Host Constants"),
|
||||||
|
("VerificationPass", "Verify PIM"),
|
||||||
|
("EmitPimJsonPass", "Emit PIM JSON"),
|
||||||
|
)
|
||||||
|
PIM_PASS_LABEL_BY_SUFFIX = dict(PIM_PASS_LABELS)
|
||||||
|
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_pim_pass_timings(output_text):
|
||||||
|
pass_timings = {}
|
||||||
|
for line in output_text.splitlines():
|
||||||
|
match = TIMING_LINE_RE.match(line)
|
||||||
|
if not match:
|
||||||
|
continue
|
||||||
|
|
||||||
|
duration = float(match.group(1))
|
||||||
|
pass_name = match.group(2)
|
||||||
|
for suffix, label in PIM_PASS_LABEL_BY_SUFFIX.items():
|
||||||
|
if pass_name.endswith(suffix):
|
||||||
|
pass_timings[label] = pass_timings.get(label, 0.0) + duration
|
||||||
|
break
|
||||||
|
|
||||||
|
if not pass_timings:
|
||||||
|
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
|
||||||
|
return pass_timings
|
||||||
|
|
||||||
|
|
||||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
||||||
@@ -16,16 +48,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
|||||||
# "--use-experimental-conv-impl=true",
|
# "--use-experimental-conv-impl=true",
|
||||||
f"--crossbar-size={crossbar_size}",
|
f"--crossbar-size={crossbar_size}",
|
||||||
f"--crossbar-count={crossbar_count}",
|
f"--crossbar-count={crossbar_count}",
|
||||||
|
"--enable-timing",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
run_command_with_reporter(
|
output_text = run_command_with_reporter(
|
||||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
|
capture_output=True,
|
||||||
)
|
)
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||||
|
return _parse_pim_pass_timings(output_text)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
|||||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||||
selector = selectors.DefaultSelector()
|
selector = selectors.DefaultSelector()
|
||||||
recent_output = bytearray()
|
recent_output = bytearray()
|
||||||
|
captured_output = bytearray()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
selector.register(fd, selectors.EVENT_READ)
|
selector.register(fd, selectors.EVENT_READ)
|
||||||
@@ -34,6 +35,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
|||||||
reporter._clear()
|
reporter._clear()
|
||||||
os.write(1, data)
|
os.write(1, data)
|
||||||
reporter._render()
|
reporter._render()
|
||||||
|
captured_output.extend(data)
|
||||||
recent_output.extend(data)
|
recent_output.extend(data)
|
||||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||||
@@ -43,12 +45,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
|||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||||
|
return bytes(captured_output)
|
||||||
|
|
||||||
|
|
||||||
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
|
if capture_output:
|
||||||
|
completed = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
cwd=cwd,
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
)
|
||||||
|
return completed.stdout.decode("utf-8", errors="replace")
|
||||||
subprocess.run(cmd, cwd=cwd, check=True)
|
subprocess.run(cmd, cwd=cwd, check=True)
|
||||||
return
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
master_fd, slave_fd = pty.openpty()
|
master_fd, slave_fd = pty.openpty()
|
||||||
@@ -60,8 +72,8 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
|||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
)
|
)
|
||||||
assert process.stdout is not None
|
assert process.stdout is not None
|
||||||
_stream_output(process.stdout.fileno(), process, reporter)
|
output = _stream_output(process.stdout.fileno(), process, reporter)
|
||||||
return
|
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
@@ -73,4 +85,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
|||||||
finally:
|
finally:
|
||||||
os.close(slave_fd)
|
os.close(slave_fd)
|
||||||
|
|
||||||
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
||||||
|
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Style, Fore
|
from colorama import Style, Fore
|
||||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||||
|
from raptor import PIM_PASS_LABELS
|
||||||
|
|
||||||
|
|
||||||
def format_command(cmd):
|
def format_command(cmd):
|
||||||
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
|
|||||||
reporter.resume()
|
reporter.resume()
|
||||||
|
|
||||||
|
|
||||||
|
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
|
||||||
|
if timed_benchmark_count == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL)
|
||||||
|
for _, label in PIM_PASS_LABELS:
|
||||||
|
count = pass_timing_counts[label]
|
||||||
|
if count == 0:
|
||||||
|
continue
|
||||||
|
print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")
|
||||||
|
print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||||
@@ -90,11 +104,15 @@ def main():
|
|||||||
print("=" * 72)
|
print("=" * 72)
|
||||||
|
|
||||||
results = {} # relative_path -> passed
|
results = {} # relative_path -> passed
|
||||||
|
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
|
||||||
|
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
|
||||||
|
total_timing_sum = 0.0
|
||||||
|
timed_benchmark_count = 0
|
||||||
reporter = ProgressReporter(len(onnx_files))
|
reporter = ProgressReporter(len(onnx_files))
|
||||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||||
rel = onnx_path.relative_to(operations_dir)
|
rel = onnx_path.relative_to(operations_dir)
|
||||||
try:
|
try:
|
||||||
passed = validate_network(
|
result = validate_network(
|
||||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||||
threshold=a.threshold,
|
threshold=a.threshold,
|
||||||
@@ -102,7 +120,15 @@ def main():
|
|||||||
model_index=index,
|
model_index=index,
|
||||||
model_total=len(onnx_files),
|
model_total=len(onnx_files),
|
||||||
)
|
)
|
||||||
results[str(rel)] = passed
|
results[str(rel)] = result.passed
|
||||||
|
if result.pim_pass_timings:
|
||||||
|
benchmark_total = 0.0
|
||||||
|
for label, duration in result.pim_pass_timings.items():
|
||||||
|
pass_timing_sums[label] += duration
|
||||||
|
pass_timing_counts[label] += 1
|
||||||
|
benchmark_total += duration
|
||||||
|
total_timing_sum += benchmark_total
|
||||||
|
timed_benchmark_count += 1
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
results[str(rel)] = False
|
results[str(rel)] = False
|
||||||
print_validation_error(reporter, rel, exc)
|
print_validation_error(reporter, rel, exc)
|
||||||
@@ -131,6 +157,12 @@ def main():
|
|||||||
print(separator)
|
print(separator)
|
||||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||||
|
print_average_pim_pass_timings(
|
||||||
|
pass_timing_sums,
|
||||||
|
pass_timing_counts,
|
||||||
|
total_timing_sum,
|
||||||
|
timed_benchmark_count,
|
||||||
|
)
|
||||||
|
|
||||||
sys.exit(0 if n_passed == n_total else 1)
|
sys.exit(0 if n_passed == n_total else 1)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Style, Fore
|
from colorama import Style, Fore
|
||||||
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
||||||
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
|
|||||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
passed: bool
|
||||||
|
pim_pass_timings: dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ProgressReporter:
|
class ProgressReporter:
|
||||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
||||||
self.total_models = total_models
|
self.total_models = total_models
|
||||||
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||||
failed_with_exception = False
|
failed_with_exception = False
|
||||||
|
pim_pass_timings = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||||
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||||
compile_with_raptor(
|
pim_pass_timings = compile_with_raptor(
|
||||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||||
crossbar_size, crossbar_count,
|
crossbar_size, crossbar_count,
|
||||||
cwd=raptor_dir, reporter=reporter)
|
cwd=raptor_dir, reporter=reporter)
|
||||||
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.record_result(passed)
|
reporter.record_result(passed)
|
||||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||||
return passed
|
return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
|
||||||
except Exception:
|
except Exception:
|
||||||
failed_with_exception = True
|
failed_with_exception = True
|
||||||
reporter.record_result(False)
|
reporter.record_result(False)
|
||||||
@@ -352,4 +360,4 @@ if __name__ == '__main__':
|
|||||||
passed = validate_network(
|
passed = validate_network(
|
||||||
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
||||||
)
|
)
|
||||||
raise SystemExit(0 if passed else 1)
|
raise SystemExit(0 if passed.passed else 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user