#!/usr/bin/env python3 """Check whether a PIM bufferized IR dump matches the YOLO depth_35 tail signature.""" from __future__ import annotations import argparse import re from pathlib import Path NONZERO_BOX_SUBVIEW_RE = re.compile( r"memref\.subview .*?: .* to " r"memref<1x2x8400xf32, strided<\[[0-9]+, 8400, 1\], offset: 16800>>" ) FINAL_CONCAT_RE = re.compile( r"pim\.concat axis 1 .*?: " r"\(memref<1x4x8400xf32(?:, strided<[^>]+>)?>, " r"memref<1x80x8400xf32(?:, strided<[^>]+>)?>\) -> memref<1x84x8400xf32>" ) FINAL_OUTPUT_RE = re.compile(r"memref<1x84x8400xf32>") def resolve_pim1_buff(path_str: str) -> Path: path = Path(path_str) if path.is_dir(): candidate = path / "pim1_buff.mlir" if candidate.is_file(): return candidate candidate = path / "dialects" / "pim1_buff.mlir" if candidate.is_file(): return candidate if path.is_file(): return path raise FileNotFoundError(f"could not find pim1_buff.mlir under {path}") def find_line_offsets(text: str) -> list[int]: offsets = [0] for match in re.finditer(r"\n", text): offsets.append(match.end()) return offsets def line_number(offsets: list[int], position: int) -> int: lo = 0 hi = len(offsets) while lo + 1 < hi: mid = (lo + hi) // 2 if offsets[mid] <= position: lo = mid else: hi = mid return lo + 1 def extract_block(lines: list[str], center_line: int, radius: int = 1) -> str: start = max(center_line - radius - 1, 0) end = min(center_line + radius, len(lines)) return "\n".join(lines[start:end]) def check_signature(path: Path) -> dict[str, object]: text = path.read_text() lines = text.splitlines() offsets = find_line_offsets(text) subview_match = NONZERO_BOX_SUBVIEW_RE.search(text) subview_var = None subview_line = None subview_snippet = None if subview_match: line_no = line_number(offsets, subview_match.start()) subview_line = line_no subview_snippet = extract_block(lines, line_no, radius=0) line_text = lines[line_no - 1] lhs = line_text.split("=", 1)[0].strip() subview_var = lhs direct_feed_matches: list[tuple[str, int, str]] = [] if subview_var: for idx, line in enumerate(lines, start=1): if subview_var not in line: continue if not re.search(r"pim\.vv(add|sub|mul)|pim\.concat", line): continue direct_feed_matches.append((line.strip(), idx, extract_block(lines, idx, radius=0))) subview_used_as_dest = any(re.search(rf"pim\.vv(add|sub|mul)\([^)]*, [^)]*, {re.escape(subview_var)}\)", line) for line, _, _ in direct_feed_matches) if subview_var else False final_concat_match = FINAL_CONCAT_RE.search(text) final_output_shape = bool(FINAL_OUTPUT_RE.search(text)) return { "path": path, "has_nonzero_box_subview": bool(subview_match), "nonzero_box_subview_line": subview_line, "nonzero_box_subview_snippet": subview_snippet, "subview_var": subview_var, "direct_pim_uses": direct_feed_matches, "subview_used_as_dest": subview_used_as_dest, "has_final_output_concat": bool(final_concat_match), "has_final_output_shape": final_output_shape, } def print_report(result: dict[str, object]) -> None: path = result["path"] print(f"== {path} ==") print(f"nonzero_box_subview: {result['has_nonzero_box_subview']}") if result["nonzero_box_subview_snippet"]: print(result["nonzero_box_subview_snippet"]) direct_uses = result["direct_pim_uses"] print(f"direct_pim_use_count: {len(direct_uses)}") for line, line_no, snippet in direct_uses: print(f"line {line_no}: {line}") print(snippet) print(f"subview_used_as_destination: {result['subview_used_as_dest']}") print(f"final_output_concat_4_80_to_84: {result['has_final_output_concat']}") print(f"contains_output_shape_1x84x8400: {result['has_final_output_shape']}") structurally_equivalent = ( result["has_nonzero_box_subview"] and bool(direct_uses) and result["has_final_output_concat"] and result["has_final_output_shape"] ) print(f"structurally_equivalent_to_yolo_tail: {structurally_equivalent}") print() def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("paths", nargs="+", help="Workspace, dialect dir, or pim1_buff.mlir to inspect") args = parser.parse_args() for path_str in args.paths: result = check_signature(resolve_pim1_buff(path_str)) print_report(result) return 0 if __name__ == "__main__": raise SystemExit(main())