276 lines
12 KiB
Python
276 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse, os, pathlib, textwrap
|
|
from onnx_utils import onnx_io
|
|
from onnx import TensorProto
|
|
|
|
# Maps an ONNX element type (a TensorProto enum value) to the triple the C
# code generator needs:
#   (C type name, printf format for that type, onnx-mlir ONNX_TYPE_* constant
#    passed to omTensorCreateWithOwnership).
# Types absent from this table make gen_c() raise ValueError.
DTYPES = {
    TensorProto.FLOAT:   ("float",    "%g",   "ONNX_TYPE_FLOAT"),
    TensorProto.DOUBLE:  ("double",   "%g",   "ONNX_TYPE_DOUBLE"),
    TensorProto.INT64:   ("int64_t",  "%lld", "ONNX_TYPE_INT64"),
    TensorProto.INT32:   ("int32_t",  "%d",   "ONNX_TYPE_INT32"),
    TensorProto.UINT8:   ("uint8_t",  "%u",   "ONNX_TYPE_UINT8"),
    TensorProto.INT8:    ("int8_t",   "%d",   "ONNX_TYPE_INT8"),
    TensorProto.BOOL:    ("uint8_t",  "%u",   "ONNX_TYPE_BOOL"),     # stored as byte
    TensorProto.FLOAT16: ("uint16_t", "%u",   "ONNX_TYPE_FLOAT16"),  # raw 16-bit bit pattern, not decoded to float
    TensorProto.BFLOAT16:("uint16_t", "%u",   "ONNX_TYPE_BFLOAT16"), # raw 16-bit bit pattern, not decoded to float
}
|
|
|
|
def esc(s): return s.replace("\\","\\\\").replace('"','\\"')
|
|
|
|
def gen_c(inputs, outputs, entry, so_name):
    """Build the C source of a command-line runner for an onnx-mlir model.

    The generated program parses, per model input K, the flags
    ``--inK-csv "v1,v2,..."`` / ``--inK-csv-file path`` / ``--inK-fill c``
    (plus optional ``--inK-shape 1x...xD`` to override the default shape),
    builds the input OMTensors, calls the model entry point, prints every
    output, and optionally dumps each output as CSV into ``--save-csv-dir``.

    Parameters:
        inputs:  sequence of (index, name, elem_type, shape) per model input.
        outputs: sequence of (index, name, elem_type, shape) per model output.
        entry:   exported entry-point symbol (e.g. "run_main_graph").
        so_name: basename of the model shared library; currently unused, kept
                 for interface compatibility.

    Returns the complete C translation unit as a string.
    Raises ValueError when an input/output element type is not in DTYPES.

    NOTE(review): dynamic dims would render as e.g. "None" in def_shape[] —
    presumably onnx_io() only yields concrete integer dims; confirm upstream.
    """
    in_blocks = []
    for i, name, et, shape in inputs:
        if et not in DTYPES:
            raise ValueError(f"Unsupported dtype for input '{name}': {et}")
        cty, pfmt, onnx_ty = DTYPES[et]
        # Use "0" for a rank-0 shape: an empty initializer list
        # (`int64_t def_shape[]={};`) is not valid standard C, and the copy
        # loop below runs zero times anyway, so the placeholder is never read.
        shp_list = ", ".join(str(d) for d in shape) if shape else "0"
        rank = len(shape)
        in_blocks.append(textwrap.dedent(f"""
        // ---- Input {i}: "{esc(name)}" ({cty}) ----
        const char *in{i}_csv=NULL, *in{i}_csv_file=NULL, *in{i}_shape_str=NULL;
        char *in{i}_csv_buf=NULL; // holds file contents if --in{i}-csv-file used
        int has_in{i}=0; double in{i}_fill=0.0; int in{i}_fill_set=0;

        // Exact-match option parsing (strcmp, not strncmp with hard-coded
        // lengths) so option names for multi-digit indices cannot collide
        // by prefix (e.g. --in1-csv vs --in10-csv).
        for (int ai=1; ai<argc; ++ai) {{
        if (strcmp(argv[ai],"--in{i}-csv")==0 && ai+1<argc) {{ in{i}_csv=argv[ai+1]; has_in{i}=1; }}
        else if (strcmp(argv[ai],"--in{i}-csv-file")==0 && ai+1<argc) {{ in{i}_csv_file=argv[ai+1]; has_in{i}=1; }}
        else if (strcmp(argv[ai],"--in{i}-fill")==0 && ai+1<argc) {{ in{i}_fill=atof(argv[ai+1]); in{i}_fill_set=1; has_in{i}=1; }}
        else if (strcmp(argv[ai],"--in{i}-shape")==0 && ai+1<argc) {{ in{i}_shape_str=argv[ai+1]; }}
        }}
        if (!has_in{i}) {{
        fprintf(stderr,"ERROR: provide one of --in{i}-csv/--in{i}-csv-file/--in{i}-fill for input {i}.\\n");
        return 2;
        }}

        // If a CSV file was provided, read it fully and use its content as the CSV string.
        if (in{i}_csv_file && !in{i}_csv) {{
        FILE *f=fopen(in{i}_csv_file,"rb");
        if(!f){{perror("fopen --in{i}-csv-file"); return 2;}}
        fseek(f, 0, SEEK_END);
        long sz = ftell(f); if (sz < 0) {{ perror("ftell"); fclose(f); return 2; }}
        fseek(f, 0, SEEK_SET);
        in{i}_csv_buf = (char*)malloc((size_t)sz + 1);
        if(!in{i}_csv_buf){{fprintf(stderr,"OOM reading --in{i}-csv-file.\\n"); fclose(f); return 2;}}
        size_t got = fread(in{i}_csv_buf, 1, (size_t)sz, f);
        fclose(f);
        if (got != (size_t)sz) {{ fprintf(stderr,"ERROR: short read for --in{i}-csv-file.\\n"); free(in{i}_csv_buf); return 2; }}
        in{i}_csv_buf[sz] = '\\0';
        in{i}_csv = in{i}_csv_buf;
        }}

        // Shape: either parsed from --in{i}-shape ("AxBxC"), or the model default.
        int64_t *in{i}_shape=NULL; int in{i}_rank=0;
        if (in{i}_shape_str) {{
        char *tmp=strdup(in{i}_shape_str);
        for(char*p=tmp; *p; ++p) if(*p=='x'||*p=='X') in{i}_rank++;
        in{i}_rank++;
        in{i}_shape=(int64_t*)malloc(sizeof(int64_t)*in{i}_rank);
        int di=0; char *tok=strtok(tmp,"xX");
        while(tok && di<in{i}_rank) {{ in{i}_shape[di++]=atoll(tok); tok=strtok(NULL,"xX"); }}
        free(tmp);
        }} else {{
        in{i}_rank={rank};
        in{i}_shape=(int64_t*)malloc(sizeof(int64_t)*in{i}_rank);
        int64_t def_shape[]={{{shp_list}}};
        for(int k=0;k<in{i}_rank;k++) in{i}_shape[k]=def_shape[k];
        }}
        long long in{i}_nelem=1; for(int k=0;k<in{i}_rank;k++) in{i}_nelem*=in{i}_shape[k];

        size_t in{i}_bytes = sizeof({cty}) * (size_t)in{i}_nelem;
        void *in{i}_buf = malloc(in{i}_bytes);
        if(!in{i}_buf){{fprintf(stderr,"OOM for input {i}.\\n"); if(in{i}_csv_buf) free(in{i}_csv_buf); return 2;}}

        // Fill the buffer from the CSV string (exact element count required)
        // or from the scalar --in{i}-fill value.
        if (in{i}_csv) {{
        char *buf=strdup(in{i}_csv); long long idx=0; char *tok=strtok(buf,",\\n\\r\\t ");
        while(tok) {{
        if(idx>=in{i}_nelem) break;
        double v=atof(tok);
        (({cty}*)in{i}_buf)[idx++] = ({cty})v;
        tok=strtok(NULL,",\\n\\r\\t ");
        }}
        free(buf);
        if(idx!=in{i}_nelem){{fprintf(stderr,"ERROR: CSV provided %lld values, expected %lld.\\n",(long long)idx,in{i}_nelem); if(in{i}_csv_buf) free(in{i}_csv_buf); return 2;}}
        }} else if (in{i}_fill_set) {{
        {cty} vv=({cty})in{i}_fill; for(long long t=0;t<in{i}_nelem;t++) (({cty}*)in{i}_buf)[t]=vv;
        }} else {{
        fprintf(stderr,"ERROR: no data source for input {i}.\\n"); if(in{i}_csv_buf) free(in{i}_csv_buf); return 2;
        }}

        OMTensor *in{i}_tensor = omTensorCreateWithOwnership(in{i}_buf, in{i}_shape, in{i}_rank, {onnx_ty}, /*owning=*/1);
        if(in{i}_csv_buf) free(in{i}_csv_buf);
        if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}}
        """))

    # Output printing + optional per-output CSV dump
    out_blocks = []
    csv_write_blocks = []
    for oi, name, et, shape in outputs:
        if et not in DTYPES:
            raise ValueError(f"Unsupported dtype for output '{name}': {et}")
        cty, pfmt, _ = DTYPES[et]
        safe = esc(name)
        out_blocks.append(textwrap.dedent(f"""
        // ---- Output {oi}: "{safe}" ----
        {{
        OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
        int64_t rank = omTensorGetRank(t);
        int64_t const *shape = omTensorGetShape(t);
        long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
        {cty} *p = ({cty}*)omTensorGetDataPtr(t);

        printf("Output {oi} ('{safe}'): shape=[");
        for (int64_t k=0;k<rank;k++) printf("%ld%s",(long)shape[k], (k+1<rank)?",":"");
        printf("]\\n");

        if (rank == 2) {{
        int64_t R = shape[0], C = shape[1];
        for (int64_t r=0; r<R; ++r) {{
        for (int64_t c=0; c<C; ++c) {{
        long long idx = r*C + c;
        printf("{pfmt}%s", p[idx], (c+1<C)?", ":"");
        }}
        printf("\\n");
        }}
        }} else {{
        // Flattened vector with indices
        for (long long i=0;i<numel;i++) {{
        printf("[%lld]={pfmt}%s", i, p[i], (i+1<numel)?", ":"\\n");
        }}
        }}
        }}
        """))

        # Per-output CSV writer into --save-csv-dir
        csv_write_blocks.append(textwrap.dedent(f"""
        if (save_csv_dir) {{
        // Build "DIR/output{oi}_<sanitized name>.csv"
        char fname[512];
        // simple sanitizer: copy name => replace non [A-Za-z0-9_.-] with '_'
        char clean[256]; int ci=0; const char *src="{safe}";
        for (; src[ci] && ci < 255; ++ci) {{
        char ch = src[ci];
        int ok = (ch>='A'&&ch<='Z')||(ch>='a'&&ch<='z')||(ch>='0'&&ch<='9')||ch=='_'||ch=='-'||ch=='.';
        clean[ci] = ok ? ch : '_';
        }}
        clean[ci] = '\\0';
        snprintf(fname, sizeof(fname), "%s/output{oi}_%s.csv", save_csv_dir, clean);
        FILE *csv = fopen(fname, "w");
        if (!csv) {{ perror("fopen --save-csv-dir"); }}
        else {{
        OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi});
        int64_t rank = omTensorGetRank(t);
        int64_t const *shape = omTensorGetShape(t);
        long long numel = 1; for (int64_t k=0;k<rank;k++) numel *= shape[k];
        {cty} *p = ({cty}*)omTensorGetDataPtr(t);

        if (rank == 2) {{
        int64_t R = shape[0], C = shape[1];
        for (int64_t r=0; r<R; ++r) {{
        for (int64_t c=0; c<C; ++c) {{
        long long idx = r*C + c;
        fprintf(csv, "{pfmt}%s", p[idx], (c+1<C)?",":"");
        }}
        fprintf(csv, "\\n");
        }}
        }} else {{
        for (long long i=0;i<numel;i++) {{
        fprintf(csv, "{pfmt}%s", p[i], (i+1<numel)?",":"");
        }}
        fprintf(csv, "\\n");
        }}
        fclose(csv);
        }}
        }}
        """))

    n_in = len(inputs)
    build_inputs = "\n".join([f" arr[{i}] = in{i}_tensor;" for i, _, _, _ in inputs])

    return f"""\
// Auto-generated onnx network runner
#include "OnnxMlirRuntime.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <inttypes.h>

OMTensorList *{entry}(OMTensorList *inputs);

int main(int argc, char **argv) {{
// optional: --save-csv-dir <DIR> (directory must exist)
const char *save_csv_dir = NULL;
for (int ai=1; ai<argc; ++ai) {{
// Exact match: strncmp(...,14) would also accept e.g. "--save-csv-dirX".
if (strcmp(argv[ai], "--save-csv-dir")==0 && ai+1 < argc) {{
save_csv_dir = argv[ai+1];
}}
}}

if (argc == 1) {{
fprintf(stderr,
"Usage: %s "
"[--inK-csv \\\"v1,v2,...\\\" | --inK-csv-file path | --inK-fill c] "
"[--inK-shape 1x...xD] "
"[--save-csv-dir /path/to/dir]\\n"
"Repeat for K=0..%d.\\n",
argv[0], {max(0, n_in - 1)});
return 1;
}}

{"".join(in_blocks)}

OMTensor *arr[{n_in}];
{build_inputs}
OMTensorList *in_list = omTensorListCreate(arr, {n_in});
if(!in_list){{fprintf(stderr,"ERROR: omTensorListCreate failed.\\n");return 2;}}

OMTensorList *out_list = {entry}(in_list);
if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}}

// ---- Print full outputs ----
{"".join(out_blocks)}

// ---- Optional per-output CSV dump ----
{"".join(csv_write_blocks)}

// ---- Cleanup ----
omTensorListDestroy(in_list);
omTensorListDestroy(out_list);
return 0;
}}
"""
|
|
|
|
def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None):
    """Generate a C runner and a CMakeLists.txt for a compiled ONNX model.

    Reads the model I/O signature from *network_onnx* via onnx_io(), writes
    the generated C source to *out* (default "runner.c"), and writes a
    CMakeLists.txt next to it that links the runner against *network_so*.
    """
    model_inputs, model_outputs = onnx_io(network_onnx)
    runner_path = pathlib.Path(out or "runner.c")
    so_abs = os.path.abspath(network_so)
    include_dir = str(onnx_include_dir)

    runner_path.write_text(
        gen_c(model_inputs, model_outputs, entry, pathlib.Path(so_abs).name))

    target = runner_path.stem
    cmake_text = textwrap.dedent(f"""\
        cmake_minimum_required(VERSION 3.15)
        project(onnx_mlir_runner C)
        add_executable({target} {runner_path.name})
        target_include_directories({target} PUBLIC {esc(include_dir)})

        add_library(model_so SHARED IMPORTED)
        set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)})
        target_link_libraries({target} PUBLIC model_so)
        """)
    runner_path.with_name("CMakeLists.txt").write_text(cmake_text)

    print(f"[OK] Wrote {runner_path}")
    print("[OK] Wrote CMakeLists.txt")
|
|
|
|
if __name__ == "__main__":
    # CLI: generate runner.c + CMakeLists.txt for a compiled ONNX model.
    parser = argparse.ArgumentParser()
    parser.add_argument("--network-onnx", required=True)
    parser.add_argument("--network-so", required=True)
    parser.add_argument("--onnx-include-dir", required=True)
    parser.add_argument("--entry", default="run_main_graph")
    parser.add_argument("--out", default=None)
    args = parser.parse_args()

    gen_network_runner(args.network_onnx, args.network_so,
                       args.onnx_include_dir, args.entry, args.out)
|