This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 287ee40 [PYTHON][FFI] Speed Up get DataType (#9072)
287ee40 is described below
commit 287ee40adc7e0d98482a3b71784ccfad732a980e
Author: wangxiang2713 <[email protected]>
AuthorDate: Thu Sep 23 03:19:22 2021 +0800
[PYTHON][FFI] Speed Up get DataType (#9072)
---
python/tvm/_ffi/runtime_ctypes.py | 46 ++++++++++++++++++++++++++++++++++-----
1 file changed, 41 insertions(+), 5 deletions(-)
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 450a356..297e24d 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -72,16 +72,52 @@ class DataType(ctypes.Structure):
DataTypeCode.HANDLE: "handle",
DataTypeCode.BFLOAT: "bfloat",
}
+ NUMPY2STR = {
+ np.dtype(np.bool_): "bool",
+ np.dtype(np.int8): "int8",
+ np.dtype(np.int16): "int16",
+ np.dtype(np.int32): "int32",
+ np.dtype(np.int64): "int64",
+ np.dtype(np.uint8): "uint8",
+ np.dtype(np.uint16): "uint16",
+ np.dtype(np.uint32): "uint32",
+ np.dtype(np.uint64): "uint64",
+ np.dtype(np.float16): "float16",
+ np.dtype(np.float32): "float32",
+ np.dtype(np.float64): "float64",
+ np.dtype(np.float_): "float64",
+ }
+ STR2DTYPE = {
+ "bool": {"type_code": DataTypeCode.UINT, "bits": 1, "lanes": 1},
+ "int8": {"type_code": DataTypeCode.INT, "bits": 8, "lanes": 1},
+ "int16": {"type_code": DataTypeCode.INT, "bits": 16, "lanes": 1},
+ "int32": {"type_code": DataTypeCode.INT, "bits": 32, "lanes": 1},
+ "int64": {"type_code": DataTypeCode.INT, "bits": 64, "lanes": 1},
+ "uint8": {"type_code": DataTypeCode.UINT, "bits": 8, "lanes": 1},
+ "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
+ "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
+ "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
+ "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
+ "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
+ "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
+ }
def __init__(self, type_str):
super(DataType, self).__init__()
- if isinstance(type_str, np.dtype):
+ numpy_str_map = DataType.NUMPY2STR
+ if type_str in numpy_str_map:
+ type_str = numpy_str_map[type_str]
+ elif isinstance(type_str, np.dtype):
type_str = str(type_str)
- if type_str == "bool":
- self.bits = 1
- self.type_code = DataTypeCode.UINT
- self.lanes = 1
+ assert isinstance(type_str, str)
+
+ str_dtype_map = DataType.STR2DTYPE
+ if type_str in str_dtype_map:
+ dtype_map = str_dtype_map[type_str]
+ self.bits = dtype_map["bits"]
+ self.type_code = dtype_map["type_code"]
+ self.lanes = dtype_map["lanes"]
return
arr = type_str.split("x")