add constant folding and verification pass for pim host operations
better validation scripts output big refactors
This commit is contained in:
@@ -102,46 +102,13 @@ def gen_c(inputs, outputs, entry, so_name):
|
||||
if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}}
|
||||
"""))
|
||||
|
||||
# Output printing + optional per-output CSV dump
|
||||
out_blocks=[]
|
||||
# Optional per-output CSV dump
|
||||
csv_write_blocks=[]
|
||||
for oi,name,et,shape in outputs:
|
||||
if et not in DTYPES:
|
||||
raise ValueError(f"Unsupported dtype for output '{name}': {et}")
|
||||
cty, pfmt, _ = DTYPES[et]
|
||||
safe = esc(name)
|
||||
out_blocks.append(textwrap.dedent(f"""
|
||||
// ---- Output {oi}: "{safe}" ----
|
||||
{{
|
||||
OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
|
||||
int64_t rank = omTensorGetRank(t);
|
||||
int64_t const *shape = omTensorGetShape(t);
|
||||
long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
|
||||
{cty} *p = ({cty}*)omTensorGetDataPtr(t);
|
||||
|
||||
printf("Output {oi} ('{safe}'): shape=[");
|
||||
for (int64_t k=0;k<rank;k++) printf("%ld%s",(long)shape[k], (k+1<rank)?",":"");
|
||||
printf("]\\n");
|
||||
|
||||
if (rank == 2) {{
|
||||
int64_t R = shape[0], C = shape[1];
|
||||
for (int64_t r=0; r<R; ++r) {{
|
||||
for (int64_t c=0; c<C; ++c) {{
|
||||
long long idx = r*C + c;
|
||||
printf("{pfmt}%s", p[idx], (c+1<C)?", ":"");
|
||||
}}
|
||||
printf("\\n");
|
||||
}}
|
||||
}} else {{
|
||||
// Flattened vector with indices
|
||||
for (long long i=0;i<numel;i++) {{
|
||||
printf("[%lld]={pfmt}%s", i, p[i], (i+1<numel)?", ":"\\n");
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""))
|
||||
|
||||
# Per-output CSV writer into --save-csv-dir
|
||||
csv_write_blocks.append(textwrap.dedent(f"""
|
||||
if (save_csv_dir) {{
|
||||
// Build "DIR/output{oi}_<sanitized name>.csv"
|
||||
@@ -227,9 +194,6 @@ int main(int argc, char **argv) {{
|
||||
OMTensorList *out_list = {entry}(in_list);
|
||||
if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}}
|
||||
|
||||
// ---- Print full outputs ----
|
||||
{"".join(out_blocks)}
|
||||
|
||||
// ---- Optional per-output CSV dump ----
|
||||
{"".join(csv_write_blocks)}
|
||||
|
||||
@@ -240,7 +204,7 @@ int main(int argc, char **argv) {{
|
||||
}}
|
||||
"""
|
||||
|
||||
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None):
|
||||
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None, verbose=True):
|
||||
ins, outs = onnx_io(network_onnx)
|
||||
out_c = out or "runner.c"
|
||||
so_abs = os.path.abspath(network_so)
|
||||
@@ -260,8 +224,9 @@ set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)})
|
||||
target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so)
|
||||
"""
|
||||
pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake)
|
||||
print(f"[OK] Wrote {out_c}")
|
||||
print("[OK] Wrote CMakeLists.txt")
|
||||
if verbose:
|
||||
print(f"[OK] Wrote {out_c}")
|
||||
print("[OK] Wrote CMakeLists.txt")
|
||||
|
||||
if __name__=="__main__":
|
||||
ap=argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user