This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 287ee40  [PYTHON][FFI] Speed Up get DataType (#9072)
287ee40 is described below

commit 287ee40adc7e0d98482a3b71784ccfad732a980e
Author: wangxiang2713 <[email protected]>
AuthorDate: Thu Sep 23 03:19:22 2021 +0800

    [PYTHON][FFI] Speed Up get DataType (#9072)
---
 python/tvm/_ffi/runtime_ctypes.py | 46 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 41 insertions(+), 5 deletions(-)

diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 450a356..297e24d 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -72,16 +72,52 @@ class DataType(ctypes.Structure):
         DataTypeCode.HANDLE: "handle",
         DataTypeCode.BFLOAT: "bfloat",
     }
+    NUMPY2STR = {
+        np.dtype(np.bool_): "bool",
+        np.dtype(np.int8): "int8",
+        np.dtype(np.int16): "int16",
+        np.dtype(np.int32): "int32",
+        np.dtype(np.int64): "int64",
+        np.dtype(np.uint8): "uint8",
+        np.dtype(np.uint16): "uint16",
+        np.dtype(np.uint32): "uint32",
+        np.dtype(np.uint64): "uint64",
+        np.dtype(np.float16): "float16",
+        np.dtype(np.float32): "float32",
+        np.dtype(np.float64): "float64",
+        np.dtype(np.float_): "float64",
+    }
+    STR2DTYPE = {
+        "bool": {"type_code": DataTypeCode.UINT, "bits": 1, "lanes": 1},
+        "int8": {"type_code": DataTypeCode.INT, "bits": 8, "lanes": 1},
+        "int16": {"type_code": DataTypeCode.INT, "bits": 16, "lanes": 1},
+        "int32": {"type_code": DataTypeCode.INT, "bits": 32, "lanes": 1},
+        "int64": {"type_code": DataTypeCode.INT, "bits": 64, "lanes": 1},
+        "uint8": {"type_code": DataTypeCode.UINT, "bits": 8, "lanes": 1},
+        "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1},
+        "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1},
+        "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
+        "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
+        "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
+        "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
+    }
 
     def __init__(self, type_str):
         super(DataType, self).__init__()
-        if isinstance(type_str, np.dtype):
+        numpy_str_map = DataType.NUMPY2STR
+        if type_str in numpy_str_map:
+            type_str = numpy_str_map[type_str]
+        elif isinstance(type_str, np.dtype):
             type_str = str(type_str)
 
-        if type_str == "bool":
-            self.bits = 1
-            self.type_code = DataTypeCode.UINT
-            self.lanes = 1
+        assert isinstance(type_str, str)
+
+        str_dtype_map = DataType.STR2DTYPE
+        if type_str in str_dtype_map:
+            dtype_map = str_dtype_map[type_str]
+            self.bits = dtype_map["bits"]
+            self.type_code = dtype_map["type_code"]
+            self.lanes = dtype_map["lanes"]
             return
 
         arr = type_str.split("x")

Reply via email to