This commit is contained in:
@@ -6,15 +6,15 @@ from onnx import TensorProto
|
||||
|
||||
# ONNX dtype -> (ctype, printf, ONNX_TYPE_*)
|
||||
DTYPES = {
|
||||
TensorProto.FLOAT: ("float", "%g", "ONNX_TYPE_FLOAT"),
|
||||
TensorProto.DOUBLE: ("double", "%g", "ONNX_TYPE_DOUBLE"),
|
||||
TensorProto.INT64: ("int64_t", "%lld","ONNX_TYPE_INT64"),
|
||||
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
|
||||
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
|
||||
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
|
||||
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"), # stored as byte
|
||||
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"), # raw 16-bit
|
||||
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
|
||||
TensorProto.FLOAT: ("float", "%.9g", "ONNX_TYPE_FLOAT"),
|
||||
TensorProto.DOUBLE: ("double", "%.17g", "ONNX_TYPE_DOUBLE"),
|
||||
TensorProto.INT64: ("int64_t", "%lld", "ONNX_TYPE_INT64"),
|
||||
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
|
||||
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
|
||||
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
|
||||
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"),
|
||||
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"),
|
||||
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
|
||||
}
|
||||
|
||||
def esc(s): return s.replace("\\","\\\\").replace('"','\\"')
|
||||
|
||||
Reference in New Issue
Block a user