huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled

remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
NiccoloN
2026-05-12 10:35:44 +02:00
parent feaff820e1
commit 909c4acfdd
84 changed files with 4048 additions and 3310 deletions
+3 -5
View File
@@ -33,8 +33,6 @@ def _parse_pim_pass_timings(output_text):
pass_timings[label] = pass_timings.get(label, 0.0) + duration
break
if not pass_timings:
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
return pass_timings
@@ -43,7 +41,7 @@ def _format_command(cmd):
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None):
# Define the arguments, with the possibility to set crossbar size and count
args = [
network_path,
@@ -51,13 +49,13 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
output_base,
"--maccel=PIM",
"--EmitPimCodegen",
# "--use-experimental-conv-impl=true",
f"--crossbar-size={crossbar_size}",
f"--crossbar-count={crossbar_count}",
"--enable-timing",
]
if core_count is not None:
args.append(f"--core-count={core_count}")
if verbose:
args.append("--enable-timing")
cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args]
if reporter is not None:
+7 -5
View File
@@ -47,7 +47,9 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=
return_code = process.wait()
if return_code != 0:
error_output = captured_output if not stream_output else recent_output
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
exc = subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
exc.output_already_streamed = stream_output and bool(captured_output)
raise exc
return bytes(captured_output)
@@ -67,15 +69,15 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
stream_output = bool(getattr(reporter, "verbose", False))
if not stream_output:
process = subprocess.Popen(
completed = subprocess.run(
cmd,
cwd=cwd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False)
return output.decode("utf-8", errors="replace") if capture_output else None
if completed.returncode != 0:
raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout)
return completed.stdout.decode("utf-8", errors="replace") if capture_output else None
try:
master_fd, slave_fd = pty.openpty()
+10 -7
View File
@@ -27,7 +27,9 @@ def print_validation_error(reporter, rel, exc):
file=sys.stderr, flush=True)
if isinstance(exc, subprocess.CalledProcessError):
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
if exc.output:
if getattr(exc, "output_already_streamed", False):
print("Failure log already printed above.", file=sys.stderr, flush=True)
elif exc.output:
output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output)
if output_text:
print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True)
@@ -160,12 +162,13 @@ def main():
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
print(f"| {rel.ljust(path_width)} | {status} |")
print(separator)
print_average_pim_pass_timings(
pass_timing_sums,
pass_timing_counts,
total_timing_sum,
timed_benchmark_count,
)
if a.verbose:
print_average_pim_pass_timings(
pass_timing_sums,
pass_timing_counts,
total_timing_sum,
timed_benchmark_count,
)
sys.exit(0 if n_passed == n_total else 1)
+28 -23
View File
@@ -1,6 +1,5 @@
import json
import numpy as np
import subprocess
import shutil
import sys
from dataclasses import dataclass, field
@@ -11,7 +10,6 @@ from raptor import compile_with_raptor
from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
STAGE_TITLES = (
"Compile ONNX",
"Build Runner",
@@ -48,10 +46,12 @@ class ProgressReporter:
self.verbose = verbose
self.columns = shutil.get_terminal_size((100, 20)).columns
self.suspended = False
self.rendered_width = 0
def _clear(self):
if self.enabled:
sys.stdout.write("\033[2K\r")
sys.stdout.write("\r" + (" " * self.rendered_width) + "\r")
sys.stdout.flush()
def _render(self):
if not self.enabled or self.suspended:
@@ -70,16 +70,16 @@ class ProgressReporter:
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
counts = (
" "
+ Style.BRIGHT
+ Fore.GREEN
+ f"P:{self.passed_models}"
+ Style.RESET_ALL
+ " "
+ Style.BRIGHT
+ Fore.RED
+ f"F:{self.failed_models}"
+ Style.RESET_ALL
" "
+ Style.BRIGHT
+ Fore.GREEN
+ f"P:{self.passed_models}"
+ Style.RESET_ALL
+ " "
+ Style.BRIGHT
+ Fore.RED
+ f"F:{self.failed_models}"
+ Style.RESET_ALL
)
model_counter = ""
label = ""
@@ -92,9 +92,12 @@ class ProgressReporter:
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
label = label[:available_label_width]
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label
rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL
padded_width = max(self.rendered_width, len(plain_line))
sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line))))
sys.stdout.flush()
self.rendered_width = len(plain_line)
def log(self, message="", color=None):
if not self.verbose:
@@ -124,18 +127,19 @@ class ProgressReporter:
self._render()
def suspend(self):
if self.enabled:
self._clear()
self.suspended = True
self._clear()
sys.stdout.flush()
def resume(self):
self.suspended = False
self._render()
def finish(self):
if self.enabled:
self.suspended = True
self._clear()
sys.stdout.flush()
self.rendered_width = 0
def run_command(cmd, cwd=None, reporter=None):
@@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor):
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
run_command(
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
"--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir,
reporter=reporter,
@@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
verbose=False)
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
print_info(reporter, f"Runner built at {runner_path}")
reporter.advance()
@@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
pim_pass_timings = compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count, core_count=core_count,
cwd=raptor_dir, reporter=reporter)
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
reporter.advance()