huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -33,8 +33,6 @@ def _parse_pim_pass_timings(output_text):
|
||||
pass_timings[label] = pass_timings.get(label, 0.0) + duration
|
||||
break
|
||||
|
||||
if not pass_timings:
|
||||
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
|
||||
return pass_timings
|
||||
|
||||
|
||||
@@ -43,7 +41,7 @@ def _format_command(cmd):
|
||||
|
||||
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
|
||||
crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None):
|
||||
# Define the arguments, with the possibility to set crossbar size and count
|
||||
args = [
|
||||
network_path,
|
||||
@@ -51,13 +49,13 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
output_base,
|
||||
"--maccel=PIM",
|
||||
"--EmitPimCodegen",
|
||||
# "--use-experimental-conv-impl=true",
|
||||
f"--crossbar-size={crossbar_size}",
|
||||
f"--crossbar-count={crossbar_count}",
|
||||
"--enable-timing",
|
||||
]
|
||||
if core_count is not None:
|
||||
args.append(f"--core-count={core_count}")
|
||||
if verbose:
|
||||
args.append("--enable-timing")
|
||||
|
||||
cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args]
|
||||
if reporter is not None:
|
||||
|
||||
@@ -47,7 +47,9 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
error_output = captured_output if not stream_output else recent_output
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
|
||||
exc = subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
|
||||
exc.output_already_streamed = stream_output and bool(captured_output)
|
||||
raise exc
|
||||
return bytes(captured_output)
|
||||
|
||||
|
||||
@@ -67,15 +69,15 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
||||
|
||||
stream_output = bool(getattr(reporter, "verbose", False))
|
||||
if not stream_output:
|
||||
process = subprocess.Popen(
|
||||
completed = subprocess.run(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False)
|
||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||
if completed.returncode != 0:
|
||||
raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
|
||||
return completed.stdout.decode("utf-8", errors="replace") if capture_output else None
|
||||
|
||||
try:
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
|
||||
+10
-7
@@ -27,7 +27,9 @@ def print_validation_error(reporter, rel, exc):
|
||||
file=sys.stderr, flush=True)
|
||||
if isinstance(exc, subprocess.CalledProcessError):
|
||||
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||
if exc.output:
|
||||
if getattr(exc, "output_already_streamed", False):
|
||||
print("Failure log already printed above.", file=sys.stderr, flush=True)
|
||||
elif exc.output:
|
||||
output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output)
|
||||
if output_text:
|
||||
print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True)
|
||||
@@ -160,12 +162,13 @@ def main():
|
||||
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
|
||||
print(f"| {rel.ljust(path_width)} | {status} |")
|
||||
print(separator)
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
total_timing_sum,
|
||||
timed_benchmark_count,
|
||||
)
|
||||
if a.verbose:
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
total_timing_sum,
|
||||
timed_benchmark_count,
|
||||
)
|
||||
|
||||
sys.exit(0 if n_passed == n_total else 1)
|
||||
|
||||
|
||||
+28
-23
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
@@ -11,7 +10,6 @@ from raptor import compile_with_raptor
|
||||
from gen_network_runner import gen_network_runner
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
STAGE_TITLES = (
|
||||
"Compile ONNX",
|
||||
"Build Runner",
|
||||
@@ -48,10 +46,12 @@ class ProgressReporter:
|
||||
self.verbose = verbose
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
self.suspended = False
|
||||
self.rendered_width = 0
|
||||
|
||||
def _clear(self):
|
||||
if self.enabled:
|
||||
sys.stdout.write("\033[2K\r")
|
||||
sys.stdout.write("\r" + (" " * self.rendered_width) + "\r")
|
||||
sys.stdout.flush()
|
||||
|
||||
def _render(self):
|
||||
if not self.enabled or self.suspended:
|
||||
@@ -70,16 +70,16 @@ class ProgressReporter:
|
||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||
|
||||
counts = (
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
model_counter = ""
|
||||
label = ""
|
||||
@@ -92,9 +92,12 @@ class ProgressReporter:
|
||||
|
||||
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
|
||||
label = label[:available_label_width]
|
||||
|
||||
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
|
||||
plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label
|
||||
rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL
|
||||
padded_width = max(self.rendered_width, len(plain_line))
|
||||
sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line))))
|
||||
sys.stdout.flush()
|
||||
self.rendered_width = len(plain_line)
|
||||
|
||||
def log(self, message="", color=None):
|
||||
if not self.verbose:
|
||||
@@ -124,18 +127,19 @@ class ProgressReporter:
|
||||
self._render()
|
||||
|
||||
def suspend(self):
|
||||
if self.enabled:
|
||||
self._clear()
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
|
||||
def resume(self):
|
||||
self.suspended = False
|
||||
self._render()
|
||||
|
||||
def finish(self):
|
||||
if self.enabled:
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
self.rendered_width = 0
|
||||
|
||||
|
||||
def run_command(cmd, cwd=None, reporter=None):
|
||||
@@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor):
|
||||
|
||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
||||
run_command(
|
||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
|
||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
|
||||
"--",
|
||||
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
||||
cwd=simulator_dir,
|
||||
reporter=reporter,
|
||||
@@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.advance()
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
|
||||
verbose=False)
|
||||
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
|
||||
print_info(reporter, f"Runner built at {runner_path}")
|
||||
reporter.advance()
|
||||
@@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
pim_pass_timings = compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count, core_count=core_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
|
||||
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
|
||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||
reporter.advance()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user