huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled

remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
NiccoloN
2026-05-12 10:35:44 +02:00
parent feaff820e1
commit 909c4acfdd
84 changed files with 4048 additions and 3310 deletions
+28 -23
View File
@@ -1,6 +1,5 @@
import json
import numpy as np
import subprocess
import shutil
import sys
from dataclasses import dataclass, field
@@ -11,7 +10,6 @@ from raptor import compile_with_raptor
from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
STAGE_TITLES = (
"Compile ONNX",
"Build Runner",
@@ -48,10 +46,12 @@ class ProgressReporter:
self.verbose = verbose
self.columns = shutil.get_terminal_size((100, 20)).columns
self.suspended = False
self.rendered_width = 0
def _clear(self):
if self.enabled:
sys.stdout.write("\033[2K\r")
sys.stdout.write("\r" + (" " * self.rendered_width) + "\r")
sys.stdout.flush()
def _render(self):
if not self.enabled or self.suspended:
@@ -70,16 +70,16 @@ class ProgressReporter:
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
counts = (
" "
+ Style.BRIGHT
+ Fore.GREEN
+ f"P:{self.passed_models}"
+ Style.RESET_ALL
+ " "
+ Style.BRIGHT
+ Fore.RED
+ f"F:{self.failed_models}"
+ Style.RESET_ALL
" "
+ Style.BRIGHT
+ Fore.GREEN
+ f"P:{self.passed_models}"
+ Style.RESET_ALL
+ " "
+ Style.BRIGHT
+ Fore.RED
+ f"F:{self.failed_models}"
+ Style.RESET_ALL
)
model_counter = ""
label = ""
@@ -92,9 +92,12 @@ class ProgressReporter:
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
label = label[:available_label_width]
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label
rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL
padded_width = max(self.rendered_width, len(plain_line))
sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line))))
sys.stdout.flush()
self.rendered_width = len(plain_line)
def log(self, message="", color=None):
if not self.verbose:
@@ -124,18 +127,19 @@ class ProgressReporter:
self._render()
def suspend(self):
if self.enabled:
self._clear()
self.suspended = True
self._clear()
sys.stdout.flush()
def resume(self):
self.suspended = False
self._render()
def finish(self):
if self.enabled:
self.suspended = True
self._clear()
sys.stdout.flush()
self.rendered_width = 0
def run_command(cmd, cwd=None, reporter=None):
@@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor):
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
run_command(
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
"--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir,
reporter=reporter,
@@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
reporter.advance()
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
verbose=False)
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
print_info(reporter, f"Runner built at {runner_path}")
reporter.advance()
@@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
pim_pass_timings = compile_with_raptor(
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
crossbar_size, crossbar_count, core_count=core_count,
cwd=raptor_dir, reporter=reporter)
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
reporter.advance()