Files
Raptor/validation/tools/check_pim_tail_signature.py
ilgeco 75fb70712f
Validate Operations / validate-operations (push) Has been cancelled
CodexWorkaround
2026-06-08 11:33:36 +02:00

145 lines
4.7 KiB
Python

#!/usr/bin/env python3
"""Check whether a PIM bufferized IR dump matches the YOLO depth_35 tail signature."""
from __future__ import annotations

import argparse
import bisect
import re
from pathlib import Path
# Matches a `memref.subview` op whose result type is a 1x2x8400 f32 view with
# strides [<any integer>, 8400, 1] and a fixed element offset of 16800.
# `.` does not match newlines, so a match never spans more than one line.
# NOTE(review): 16800 == 2 * 8400, presumably the view starting at channel 2 of
# the box tensor -- confirm against the lowering that produces this IR.
NONZERO_BOX_SUBVIEW_RE = re.compile(
    r"memref\.subview .*?: .* to "
    r"memref<1x2x8400xf32, strided<\[[0-9]+, 8400, 1\], offset: 16800>>"
)
# Matches a `pim.concat axis 1` op whose operand types are (1x4x8400, 1x80x8400)
# f32 memrefs -- each optionally carrying a `strided<...>` layout -- and whose
# result type is a plain `memref<1x84x8400xf32>`.
FINAL_CONCAT_RE = re.compile(
    r"pim\.concat axis 1 .*?: "
    r"\(memref<1x4x8400xf32(?:, strided<[^>]+>)?>, "
    r"memref<1x80x8400xf32(?:, strided<[^>]+>)?>\) -> memref<1x84x8400xf32>"
)
# Matches any occurrence of the 1x84x8400 f32 memref type, regardless of which
# op (if any) it appears in.
FINAL_OUTPUT_RE = re.compile(r"memref<1x84x8400xf32>")
def resolve_pim1_buff(path_str: str) -> Path:
    """Locate the ``pim1_buff.mlir`` dump that ``path_str`` refers to.

    Args:
        path_str: Either a directory or a file path.

    Returns:
        For a directory: ``<dir>/pim1_buff.mlir`` if it is a file, otherwise
        ``<dir>/dialects/pim1_buff.mlir`` if that is a file. For a regular
        file: the path itself, whatever its name.

    Raises:
        FileNotFoundError: If none of the above resolves to an existing file.
    """
    root = Path(path_str)

    if root.is_dir():
        # Probe the known locations in priority order; the first hit wins.
        known_locations = (
            ("pim1_buff.mlir",),
            ("dialects", "pim1_buff.mlir"),
        )
        for parts in known_locations:
            probe = root.joinpath(*parts)
            if probe.is_file():
                return probe
    elif root.is_file():
        return root

    raise FileNotFoundError(f"could not find pim1_buff.mlir under {root}")
def find_line_offsets(text: str) -> list[int]:
    """Return the character offset at which each line of ``text`` begins.

    Only ``"\\n"`` is treated as a line terminator. The result always starts
    with ``0`` and gains one entry (the index just past the newline) for every
    ``"\\n"`` in ``text``, so it is sorted in ascending order.
    """
    starts = [0]
    newline_at = text.find("\n")
    while newline_at != -1:
        # The next line begins immediately after this newline character.
        starts.append(newline_at + 1)
        newline_at = text.find("\n", newline_at + 1)
    return starts
def line_number(offsets: list[int], position: int) -> int:
    """Map a character ``position`` to its 1-based line number.

    Args:
        offsets: Ascending line-start offsets, as produced by
            ``find_line_offsets`` (first entry ``0``).
        position: Character index into the text the offsets were built from.

    Returns:
        The 1-based index of the last entry in ``offsets`` that is
        ``<= position``. The result is never less than 1, including when
        ``offsets`` is empty or ``position`` precedes every offset.
    """
    # bisect_right counts how many line starts are <= position, which is
    # exactly the 1-based line number; clamp so degenerate inputs still map to
    # line 1 as the previous hand-written search did.
    return max(bisect.bisect_right(offsets, position), 1)
def extract_block(lines: list[str], center_line: int, radius: int = 1) -> str:
    """Return the lines surrounding ``center_line`` joined with newlines.

    Args:
        lines: The text split into lines.
        center_line: 1-based number of the line of interest.
        radius: How many lines of context to include on each side.

    Returns:
        Lines ``center_line - radius`` through ``center_line + radius``
        (1-based, inclusive), clipped to the bounds of ``lines``.
    """
    # Convert the 1-based inclusive window into a 0-based half-open slice.
    first = center_line - radius - 1
    if first < 0:
        first = 0

    last = center_line + radius
    total = len(lines)
    if last > total:
        last = total

    window = lines[first:last]
    return "\n".join(window)
def check_signature(path: Path) -> dict[str, object]:
    """Scan an IR dump for the pieces of the tail signature.

    The file is searched for the first ``NONZERO_BOX_SUBVIEW_RE`` match. The
    text left of the first ``=`` on that line is taken as the subview's SSA
    name, and every line that mentions that name together with a
    ``pim.vvadd``/``pim.vvsub``/``pim.vvmul``/``pim.concat`` op is collected.

    Args:
        path: File to read.

    Returns:
        A dict with the keys:

        * ``path``: the ``path`` argument.
        * ``has_nonzero_box_subview``: whether the subview pattern matched.
        * ``nonzero_box_subview_line``: 1-based line of that match, or ``None``.
        * ``nonzero_box_subview_snippet``: text of that line, or ``None``.
        * ``subview_var``: SSA name defined on that line, or ``None``.
        * ``direct_pim_uses``: ``(stripped_line, line_no, raw_line)`` tuples
          for each PIM op line that references ``subview_var``.
        * ``subview_used_as_dest``: whether any of those lines is a
          ``pim.vv*`` call with ``subview_var`` as its last argument.
        * ``has_final_output_concat``: whether ``FINAL_CONCAT_RE`` matched.
        * ``has_final_output_shape``: whether ``FINAL_OUTPUT_RE`` matched.
    """
    text = path.read_text()
    # Split on "\n" only. find_line_offsets() counts just "\n", whereas
    # str.splitlines() also breaks on \f, \v, U+2028, ...; mixing the two
    # would let a computed line number index the wrong entry of `lines`.
    lines = text.split("\n")
    offsets = find_line_offsets(text)

    subview_match = NONZERO_BOX_SUBVIEW_RE.search(text)
    subview_var = None
    subview_line = None
    subview_snippet = None
    if subview_match:
        subview_line = line_number(offsets, subview_match.start())
        subview_snippet = extract_block(lines, subview_line, radius=0)
        # The SSA result name is whatever precedes the first "=" on the line.
        subview_var = lines[subview_line - 1].split("=", 1)[0].strip()

    direct_feed_matches: list[tuple[str, int, str]] = []
    subview_used_as_dest = False
    if subview_var:
        escaped_var = re.escape(subview_var)
        # Require the name to end where it is matched, so that e.g.
        # "%subview_1" does not also count uses of "%subview_10". The excluded
        # characters are those MLIR allows inside a value identifier.
        var_use_re = re.compile(escaped_var + r"(?![A-Za-z0-9_$.\-])")
        pim_op_re = re.compile(r"pim\.vv(add|sub|mul)|pim\.concat")
        dest_re = re.compile(
            rf"pim\.vv(add|sub|mul)\([^)]*, [^)]*, {escaped_var}\)"
        )

        for idx, line in enumerate(lines, start=1):
            if not var_use_re.search(line):
                continue
            if not pim_op_re.search(line):
                continue
            direct_feed_matches.append(
                (line.strip(), idx, extract_block(lines, idx, radius=0))
            )

        subview_used_as_dest = any(
            dest_re.search(use_line) for use_line, _, _ in direct_feed_matches
        )

    final_concat_match = FINAL_CONCAT_RE.search(text)
    final_output_shape = bool(FINAL_OUTPUT_RE.search(text))

    return {
        "path": path,
        "has_nonzero_box_subview": bool(subview_match),
        "nonzero_box_subview_line": subview_line,
        "nonzero_box_subview_snippet": subview_snippet,
        "subview_var": subview_var,
        "direct_pim_uses": direct_feed_matches,
        "subview_used_as_dest": subview_used_as_dest,
        "has_final_output_concat": bool(final_concat_match),
        "has_final_output_shape": final_output_shape,
    }
def print_report(result: dict[str, object]) -> None:
    """Write a plain-text summary of a ``check_signature`` result to stdout.

    Args:
        result: Mapping with the keys ``path``, ``has_nonzero_box_subview``,
            ``nonzero_box_subview_snippet``, ``direct_pim_uses``,
            ``subview_used_as_dest``, ``has_final_output_concat`` and
            ``has_final_output_shape``.

    The final verdict line is true only when the subview was found, at least
    one direct PIM use of it exists, and both final-output checks passed. The
    report ends with a blank line.
    """
    uses = result["direct_pim_uses"]
    snippet = result["nonzero_box_subview_snippet"]

    report = [
        f"== {result['path']} ==",
        f"nonzero_box_subview: {result['has_nonzero_box_subview']}",
    ]
    if snippet:
        report.append(str(snippet))

    report.append(f"direct_pim_use_count: {len(uses)}")
    for use_text, use_line_no, use_context in uses:
        report.append(f"line {use_line_no}: {use_text}")
        report.append(str(use_context))

    report.append(f"subview_used_as_destination: {result['subview_used_as_dest']}")
    report.append(f"final_output_concat_4_80_to_84: {result['has_final_output_concat']}")
    report.append(f"contains_output_shape_1x84x8400: {result['has_final_output_shape']}")

    # All four conditions must hold for the dump to count as equivalent.
    verdict = (
        result["has_nonzero_box_subview"]
        and bool(uses)
        and result["has_final_output_concat"]
        and result["has_final_output_shape"]
    )
    report.append(f"structurally_equivalent_to_yolo_tail: {verdict}")

    # Trailing empty entry yields the blank separator line after the report.
    report.append("")
    print("\n".join(report))
def main() -> int:
    """Command-line entry point: report on every path given as an argument.

    Each positional argument is resolved with ``resolve_pim1_buff``, analysed
    with ``check_signature`` and printed with ``print_report``.

    Returns:
        Always ``0``; the analysis outcome is reported on stdout only.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "paths",
        nargs="+",
        help="Workspace, dialect dir, or pim1_buff.mlir to inspect",
    )
    requested = cli.parse_args().paths

    for raw_path in requested:
        dump_file = resolve_pim1_buff(raw_path)
        print_report(check_signature(dump_file))

    return 0
# When run as a script, use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())