huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
+28
-23
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
@@ -11,7 +10,6 @@ from raptor import compile_with_raptor
|
||||
from gen_network_runner import gen_network_runner
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
STAGE_TITLES = (
|
||||
"Compile ONNX",
|
||||
"Build Runner",
|
||||
@@ -48,10 +46,12 @@ class ProgressReporter:
|
||||
self.verbose = verbose
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
self.suspended = False
|
||||
self.rendered_width = 0
|
||||
|
||||
def _clear(self):
|
||||
if self.enabled:
|
||||
sys.stdout.write("\033[2K\r")
|
||||
sys.stdout.write("\r" + (" " * self.rendered_width) + "\r")
|
||||
sys.stdout.flush()
|
||||
|
||||
def _render(self):
|
||||
if not self.enabled or self.suspended:
|
||||
@@ -70,16 +70,16 @@ class ProgressReporter:
|
||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||
|
||||
counts = (
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
model_counter = ""
|
||||
label = ""
|
||||
@@ -92,9 +92,12 @@ class ProgressReporter:
|
||||
|
||||
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
|
||||
label = label[:available_label_width]
|
||||
|
||||
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
|
||||
plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label
|
||||
rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL
|
||||
padded_width = max(self.rendered_width, len(plain_line))
|
||||
sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line))))
|
||||
sys.stdout.flush()
|
||||
self.rendered_width = len(plain_line)
|
||||
|
||||
def log(self, message="", color=None):
|
||||
if not self.verbose:
|
||||
@@ -124,18 +127,19 @@ class ProgressReporter:
|
||||
self._render()
|
||||
|
||||
def suspend(self):
|
||||
if self.enabled:
|
||||
self._clear()
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
|
||||
def resume(self):
|
||||
self.suspended = False
|
||||
self._render()
|
||||
|
||||
def finish(self):
|
||||
if self.enabled:
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
sys.stdout.flush()
|
||||
self.rendered_width = 0
|
||||
|
||||
|
||||
def run_command(cmd, cwd=None, reporter=None):
|
||||
@@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor):
|
||||
|
||||
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
|
||||
run_command(
|
||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
|
||||
["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator",
|
||||
"--",
|
||||
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
|
||||
cwd=simulator_dir,
|
||||
reporter=reporter,
|
||||
@@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.advance()
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner")
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False)
|
||||
gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c",
|
||||
verbose=False)
|
||||
runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter)
|
||||
print_info(reporter, f"Runner built at {runner_path}")
|
||||
reporter.advance()
|
||||
@@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
pim_pass_timings = compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count, core_count=core_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count,
|
||||
core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter)
|
||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||
reporter.advance()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user