This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 408aa78 [DTYPE] Include dtype literals (#263)
408aa78 is described below
commit 408aa78c4e7036127238ca626fdcb23e11103527
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 14 19:24:42 2025 -0500
[DTYPE] Include dtype literals (#263)
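This PR adds module-level dtype literals (e.g. `tvm_ffi.int32`, `tvm_ffi.float16`, `tvm_ffi.float8_e4m3fn`) and makes `convert` pass `dtype` values through unchanged.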
---
python/tvm_ffi/__init__.py | 22 ++++++++++++++++++++++
python/tvm_ffi/_convert.py | 4 +++-
python/tvm_ffi/_dtype.py | 28 ++++++++++++++++++++++++++++
tests/python/test_dtype.py | 22 ++++++++++++++++++++++
4 files changed, 75 insertions(+), 1 deletion(-)
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 22bdd52..95ca145 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -60,6 +60,28 @@ from . import cpp
# optional module to speedup dlpack conversion
from . import _optional_torch_c_dlpack
+# re-export dtype literals from _dtype; NOTE(review): `bool` shadows the builtin in this module
+from ._dtype import (
+ bool,
+ int8,
+ int16,
+ int32,
+ int64,
+ uint8,
+ uint16,
+ uint32,
+ uint64,
+ float64,
+ float32,
+ float16,
+ bfloat16,
+ float8_e4m3fn,
+ float8_e4m3fnuz,
+ float8_e5m2,
+ float8_e5m2fnuz,
+ float8_e8m0fnu,
+ float4_e2m1fnx2,  # NOTE(review): the torch-style alias float4_e2m1fn_x2 is not re-exported
+)
try:
from ._version import __version__, __version_tuple__ # type: ignore[import-not-found]
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index f090d28..59783f6 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -91,7 +91,9 @@ def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
only used in internal or testing scenarios.
"""
- if isinstance(value, (core.Object, core.PyNativeObject, bool, Number,
ctypes.c_void_p)):
+ if isinstance(
+ value, (core.Object, core.PyNativeObject, bool, Number,
ctypes.c_void_p, _dtype.dtype)
+ ):
return value
elif isinstance(value, (tuple, list)):
return container.Array(value)
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 216af55..65ecad2 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -321,3 +321,31 @@ except ImportError:
pass
core._set_class_dtype(dtype)
+
+# Common dtype literals used by machine learning systems.
+# Dtypes not listed here can still be constructed explicitly,
+# e.g. from a DLPack data type tuple via dtype.from_dlpack_data_type.
+# Named `bool` (not `bool_`) as in NumPy 2.0; this shadows the builtin `bool` for the rest of this module.
+bool = dtype("bool")
+int8 = dtype("int8")
+int16 = dtype("int16")
+int32 = dtype("int32")
+int64 = dtype("int64")
+uint8 = dtype("uint8")
+uint16 = dtype("uint16")
+uint32 = dtype("uint32")
+uint64 = dtype("uint64")
+float64 = dtype("float64")
+float32 = dtype("float32")
+float16 = dtype("float16")
+bfloat16 = dtype("bfloat16")
+# float8 dtypes
+float8_e4m3fn = dtype("float8_e4m3fn")
+float8_e4m3fnuz = dtype("float8_e4m3fnuz")
+float8_e5m2 = dtype("float8_e5m2")
+float8_e5m2fnuz = dtype("float8_e5m2fnuz")
+float8_e8m0fnu = dtype("float8_e8m0fnu")
+# packed float4 dtype: 4 bits per value, 2 lanes
+float4_e2m1fnx2 = dtype("float4_e2m1fnx2")
+# alias matching torch's float4_e2m1fn_x2 naming; NOTE(review): appears not to be re-exported by tvm_ffi/__init__.py
+float4_e2m1fn_x2 = float4_e2m1fnx2
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 2897faa..6817717 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -164,6 +164,28 @@ def test_ml_dtypes_dtype_conversion() -> None:
_check_dtype(np.dtype(ml_dtypes.float4_e2m1fn), 17, 4, 1)
+def test_builtin_dtype_conversion() -> None:
+ _check_dtype(tvm_ffi.bool, 6, 8, 1)
+ _check_dtype(tvm_ffi.int8, 0, 8, 1)
+ _check_dtype(tvm_ffi.int16, 0, 16, 1)
+ _check_dtype(tvm_ffi.int32, 0, 32, 1)
+ _check_dtype(tvm_ffi.int64, 0, 64, 1)
+ _check_dtype(tvm_ffi.uint8, 1, 8, 1)
+ _check_dtype(tvm_ffi.uint16, 1, 16, 1)
+ _check_dtype(tvm_ffi.uint32, 1, 32, 1)
+ _check_dtype(tvm_ffi.uint64, 1, 64, 1)
+ _check_dtype(tvm_ffi.float16, 2, 16, 1)
+ _check_dtype(tvm_ffi.float32, 2, 32, 1)
+ _check_dtype(tvm_ffi.float64, 2, 64, 1)
+ _check_dtype(tvm_ffi.bfloat16, 4, 16, 1)
+ _check_dtype(tvm_ffi.float8_e4m3fn, 10, 8, 1)
+ _check_dtype(tvm_ffi.float8_e4m3fnuz, 11, 8, 1)
+ _check_dtype(tvm_ffi.float8_e5m2, 12, 8, 1)
+ _check_dtype(tvm_ffi.float8_e5m2fnuz, 13, 8, 1)
+ _check_dtype(tvm_ffi.float8_e8m0fnu, 14, 8, 1)
+ _check_dtype(tvm_ffi.float4_e2m1fnx2, 17, 4, 2)
+
+
def test_dtype_from_dlpack_data_type() -> None:
dtype = tvm_ffi.dtype.from_dlpack_data_type((0, 8, 1))
assert dtype.type_code == 0