This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 408aa78  [DTYPE] Include dtype literals (#263)
408aa78 is described below

commit 408aa78c4e7036127238ca626fdcb23e11103527
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 14 19:24:42 2025 -0500

    [DTYPE] Include dtype literals (#263)
    
    This PR enhances the codebase to include dtype literals.
---
 python/tvm_ffi/__init__.py | 22 ++++++++++++++++++++++
 python/tvm_ffi/_convert.py |  4 +++-
 python/tvm_ffi/_dtype.py   | 28 ++++++++++++++++++++++++++++
 tests/python/test_dtype.py | 22 ++++++++++++++++++++++
 4 files, 75 insertions(+), 1 deletion(-)

diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 22bdd52..95ca145 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -60,6 +60,28 @@ from . import cpp
 # optional module to speedup dlpack conversion
 from . import _optional_torch_c_dlpack
 
+# import the dtype literals
+from ._dtype import (
+    bool,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    float64,
+    float32,
+    float16,
+    bfloat16,
+    float8_e4m3fn,
+    float8_e4m3fnuz,
+    float8_e5m2,
+    float8_e5m2fnuz,
+    float8_e8m0fnu,
+    float4_e2m1fnx2,
+)
 
 try:
     from ._version import __version__, __version_tuple__  # type: ignore[import-not-found]
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index f090d28..59783f6 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -91,7 +91,9 @@ def convert(value: Any) -> Any:  # noqa: PLR0911,PLR0912
     only used in internal or testing scenarios.
 
     """
-    if isinstance(value, (core.Object, core.PyNativeObject, bool, Number, ctypes.c_void_p)):
+    if isinstance(
+        value, (core.Object, core.PyNativeObject, bool, Number, ctypes.c_void_p, _dtype.dtype)
+    ):
         return value
     elif isinstance(value, (tuple, list)):
         return container.Array(value)
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 216af55..65ecad2 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -321,3 +321,31 @@ except ImportError:
     pass
 
 core._set_class_dtype(dtype)
+
+# list of common dtype literals in machine learning systems apps
+# note that we can always cover more dtypes via explicit construction
+# from dlpack data type tuple
+# align with choice of numpy 2.0, which moved away from bool_ to bool
+bool = dtype("bool")
+int8 = dtype("int8")
+int16 = dtype("int16")
+int32 = dtype("int32")
+int64 = dtype("int64")
+uint8 = dtype("uint8")
+uint16 = dtype("uint16")
+uint32 = dtype("uint32")
+uint64 = dtype("uint64")
+float64 = dtype("float64")
+float32 = dtype("float32")
+float16 = dtype("float16")
+bfloat16 = dtype("bfloat16")
+# float8 dtypes
+float8_e4m3fn = dtype("float8_e4m3fn")
+float8_e4m3fnuz = dtype("float8_e4m3fnuz")
+float8_e5m2 = dtype("float8_e5m2")
+float8_e5m2fnuz = dtype("float8_e5m2fnuz")
+float8_e8m0fnu = dtype("float8_e8m0fnu")
+# float4x2 dtypes
+float4_e2m1fnx2 = dtype("float4_e2m1fnx2")
+# alias for torch naming pattern for f4x2
+float4_e2m1fn_x2 = float4_e2m1fnx2
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 2897faa..6817717 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -164,6 +164,28 @@ def test_ml_dtypes_dtype_conversion() -> None:
     _check_dtype(np.dtype(ml_dtypes.float4_e2m1fn), 17, 4, 1)
 
 
+def test_builtin_dtype_conversion() -> None:
+    _check_dtype(tvm_ffi.bool, 6, 8, 1)
+    _check_dtype(tvm_ffi.int8, 0, 8, 1)
+    _check_dtype(tvm_ffi.int16, 0, 16, 1)
+    _check_dtype(tvm_ffi.int32, 0, 32, 1)
+    _check_dtype(tvm_ffi.int64, 0, 64, 1)
+    _check_dtype(tvm_ffi.uint8, 1, 8, 1)
+    _check_dtype(tvm_ffi.uint16, 1, 16, 1)
+    _check_dtype(tvm_ffi.uint32, 1, 32, 1)
+    _check_dtype(tvm_ffi.uint64, 1, 64, 1)
+    _check_dtype(tvm_ffi.float16, 2, 16, 1)
+    _check_dtype(tvm_ffi.float32, 2, 32, 1)
+    _check_dtype(tvm_ffi.float64, 2, 64, 1)
+    _check_dtype(tvm_ffi.bfloat16, 4, 16, 1)
+    _check_dtype(tvm_ffi.float8_e4m3fn, 10, 8, 1)
+    _check_dtype(tvm_ffi.float8_e4m3fnuz, 11, 8, 1)
+    _check_dtype(tvm_ffi.float8_e5m2, 12, 8, 1)
+    _check_dtype(tvm_ffi.float8_e5m2fnuz, 13, 8, 1)
+    _check_dtype(tvm_ffi.float8_e8m0fnu, 14, 8, 1)
+    _check_dtype(tvm_ffi.float4_e2m1fnx2, 17, 4, 2)
+
+
 def test_dtype_from_dlpack_data_type() -> None:
     dtype = tvm_ffi.dtype.from_dlpack_data_type((0, 8, 1))
     assert dtype.type_code == 0

Reply via email to