This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new c51e519  [Feature] support C++ dtype_trait and Python-side mapping to 
C++ dtype (#374)
c51e519 is described below

commit c51e519b2253c2c8754bebaf2f9af0434d89e1fc
Author: DarkSharpness <[email protected]>
AuthorDate: Fri Jan 2 02:43:05 2026 +0800

    [Feature] support C++ dtype_trait and Python-side mapping to C++ dtype 
(#374)
    
    related issue #365
    
    references:
    
    1.
    https://github.com/ROCm/clr/tree/amd-staging/hipamd/include/hip/amd_detail
    2.
    
https://github.com/pytorch/pytorch/blob/f7f91ec63a6575443b2b06ded791ac7fd9a7f66d/aten/src/ATen/DLConvertor.cpp
    3.
    https://github.com/flashinfer-ai/flashinfer/blob/main/csrc/tvm_ffi_utils.h
---
 include/tvm/ffi/extra/dtype.h  | 206 +++++++++++++++++++++++++++++++++++++++++
 python/tvm_ffi/cpp/__init__.py |   2 +
 python/tvm_ffi/cpp/dtype.py    | 104 +++++++++++++++++++++
 3 files changed, 312 insertions(+)

diff --git a/include/tvm/ffi/extra/dtype.h b/include/tvm/ffi/extra/dtype.h
new file mode 100644
index 0000000..4187d0f
--- /dev/null
+++ b/include/tvm/ffi/extra/dtype.h
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/dtype.h
+ * \brief Type traits to map C++ types to DLPack dtypes.
+ */
+#ifndef TVM_FFI_EXTRA_DTYPE_H_
+#define TVM_FFI_EXTRA_DTYPE_H_
+
+#include <dlpack/dlpack.h>
+
+#include <type_traits>
+
// Forward declarations of vendor-specific scalar types, so the trait
// specializations below can be provided without including the heavy
// CUDA / HIP headers.  Only the type names are needed here.

// Common for both CUDA and HIP
struct __half;

// CUDA
struct __nv_fp8_e4m3;
struct __nv_bfloat16;
struct __nv_fp8_e5m2;
struct __nv_fp8_e8m0;
struct __nv_fp4_e2m1;
struct __nv_fp4x2_e2m1;

// HIP
struct __hip_bfloat16;
// NOTE(review): hip_bfloat16 is declared as a struct (not an alias) to match
// the HIP SDK, where it is presumably a distinct legacy type separate from
// __hip_bfloat16 — confirm against the amd_detail headers.
struct hip_bfloat16;
struct __hip_fp8_e4m3;
struct __hip_fp8_e4m3_fnuz;
struct __hip_fp8_e5m2;
struct __hip_fp8_e5m2_fnuz;
struct __hip_fp4_e2m1;
struct __hip_fp4x2_e2m1;
+
+namespace tvm_ffi {
+
+/// \cond Doxygen_Suppress
+
/*!
 * \brief Trait mapping a C++ scalar type to its DLPack `DLDataType`.
 *
 * The primary template is intentionally empty: specializations below expose
 * `static constexpr DLDataType value`, so using an unsupported type fails at
 * compile time (no `value` member exists on the primary template).
 */
template <typename T>
struct dtype_trait {};
+
+namespace details::dtypes {
+
+template <typename T>
+struct integer_trait {
+  static constexpr DLDataType value = {
+      /* code = */ std::is_signed_v<T> ? kDLInt : kDLUInt,
+      /* bits = */ static_cast<uint8_t>(sizeof(T) * 8),
+      /* lanes = */ 1,
+  };
+};
+
+template <typename T>
+struct float_trait {
+  static constexpr DLDataType value = {
+      /* code = */ kDLFloat,
+      /* bits = */ static_cast<uint8_t>(sizeof(T) * 8),
+      /* lanes = */ 1,
+  };
+};
+
+}  // namespace details::dtypes
+
// Specializations for the standard integer types.  Fixed-width aliases such
// as int8_t / uint64_t resolve to one of these on all mainstream platforms.
// NOTE(review): plain `char` (a distinct type from signed/unsigned char) is
// not covered — presumably because its signedness is implementation-defined,
// so its DLPack code would vary across platforms; confirm this is intentional.

template <>
struct dtype_trait<signed char> : details::dtypes::integer_trait<signed char> {};

template <>
struct dtype_trait<unsigned char> : details::dtypes::integer_trait<unsigned char> {};

template <>
struct dtype_trait<signed short> : details::dtypes::integer_trait<signed short> {};

template <>
struct dtype_trait<unsigned short> : details::dtypes::integer_trait<unsigned short> {};

template <>
struct dtype_trait<signed int> : details::dtypes::integer_trait<signed int> {};

template <>
struct dtype_trait<unsigned int> : details::dtypes::integer_trait<unsigned int> {};

template <>
struct dtype_trait<signed long> : details::dtypes::integer_trait<signed long> {};

template <>
struct dtype_trait<unsigned long> : details::dtypes::integer_trait<unsigned long> {};

template <>
struct dtype_trait<signed long long> : details::dtypes::integer_trait<signed long long> {};

template <>
struct dtype_trait<unsigned long long> : details::dtypes::integer_trait<unsigned long long> {};

// Specializations for the standard floating-point types.

template <>
struct dtype_trait<float> : details::dtypes::float_trait<float> {};

template <>
struct dtype_trait<double> : details::dtypes::float_trait<double> {};

// Specialization for bool

// bool is stored as an 8-bit kDLBool value rather than going through
// integer_trait, which would classify it as an unsigned integer.
template <>
struct dtype_trait<bool> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBool, 8, 1};
};
+
// Specializations for CUDA vendor scalar types (declared above; defined in
// the CUDA headers, e.g. cuda_fp16.h / cuda_bf16.h / cuda_fp8.h).

// 16-bit float.
template <>
struct dtype_trait<__half> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat, 16, 1};
};

template <>
struct dtype_trait<__nv_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

// FP8 variants: the DLPack type code encodes the exponent/mantissa layout.
template <>
struct dtype_trait<__nv_fp8_e4m3> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1};
};

template <>
struct dtype_trait<__nv_fp8_e5m2> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1};
};

template <>
struct dtype_trait<__nv_fp8_e8m0> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e8m0fnu, 8, 1};
};

// FP4 variants: 4 bits per element.
template <>
struct dtype_trait<__nv_fp4_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1};
};

// Packed pair of e2m1 values, hence lanes = 2.
template <>
struct dtype_trait<__nv_fp4x2_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2};
};
+
// Specializations for HIP vendor scalar types (declared above; defined in
// the ROCm/HIP amd_detail headers).

template <>
struct dtype_trait<__hip_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

// Legacy HIP bfloat16 type; maps to the same DLPack dtype as __hip_bfloat16.
template <>
struct dtype_trait<hip_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

template <>
struct dtype_trait<__hip_fp8_e4m3> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1};
};

// fnuz variants get their own DLPack type codes, distinct from the plain
// e4m3fn / e5m2 layouts.
template <>
struct dtype_trait<__hip_fp8_e4m3_fnuz> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fnuz, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e5m2> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e5m2_fnuz> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2fnuz, 8, 1};
};

template <>
struct dtype_trait<__hip_fp4_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1};
};

// Packed pair of e2m1 values, hence lanes = 2.
template <>
struct dtype_trait<__hip_fp4x2_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2};
};
+
+/// \endcond
+
+}  // namespace tvm_ffi
+
+#endif  // TVM_FFI_EXTRA_DTYPE_H_
diff --git a/python/tvm_ffi/cpp/__init__.py b/python/tvm_ffi/cpp/__init__.py
index e3deb12..f2d4ce2 100644
--- a/python/tvm_ffi/cpp/__init__.py
+++ b/python/tvm_ffi/cpp/__init__.py
@@ -16,6 +16,7 @@
 # under the License.
 """C++ integration helpers for building and loading inline modules."""
 
+from .dtype import to_cpp_dtype
 from .extension import build, build_inline, load, load_inline
 
 __all__ = [
@@ -23,4 +24,5 @@ __all__ = [
     "build_inline",
     "load",
     "load_inline",
+    "to_cpp_dtype",
 ]
diff --git a/python/tvm_ffi/cpp/dtype.py b/python/tvm_ffi/cpp/dtype.py
new file mode 100644
index 0000000..4a782b0
--- /dev/null
+++ b/python/tvm_ffi/cpp/dtype.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for C++ dtype conversion."""
+
+from __future__ import annotations
+
+import functools
+from typing import Any, Literal
+
# Dtype names that map to portable C++ types usable on any backend.
CPU_DTYPE_MAP = {
    "int8": "int8_t",
    "int16": "int16_t",
    "int32": "int32_t",
    "int64": "int64_t",
    "uint8": "uint8_t",
    "uint16": "uint16_t",
    "uint32": "uint32_t",
    "uint64": "uint64_t",
    "float32": "float",
    "float64": "double",
    "bool": "bool",
}

# Dtype names that map to CUDA vendor types (require the CUDA headers).
CUDA_DTYPE_MAP = {
    "float16": "__half",
    "bfloat16": "__nv_bfloat16",
    "float8_e4m3fn": "__nv_fp8_e4m3",
    # NOTE(review): the fnuz entries below are disabled — presumably because
    # CUDA has no native fnuz types and mapping them to the plain __nv_fp8_*
    # types would reinterpret the bits; confirm before enabling.
    # "float8_e4m3fnuz": "__nv_fp8_e4m3",
    "float8_e5m2": "__nv_fp8_e5m2",
    # "float8_e5m2fnuz": "__nv_fp8_e5m2",
    "float8_e8m0fnu": "__nv_fp8_e8m0",
    "float4_e2m1": "__nv_fp4_e2m1",
    "float4_e2m1fn_x2": "__nv_fp4x2_e2m1",
}

# Dtype names that map to HIP vendor types (require the ROCm/HIP headers).
# Unlike CUDA, HIP provides dedicated fnuz types, so those are mapped here.
ROCM_DTYPE_MAP = {
    "float16": "__half",
    "bfloat16": "__hip_bfloat16",
    "float8_e4m3fn": "__hip_fp8_e4m3",
    "float8_e4m3fnuz": "__hip_fp8_e4m3_fnuz",
    "float8_e5m2": "__hip_fp8_e5m2",
    "float8_e5m2fnuz": "__hip_fp8_e5m2_fnuz",
    "float4_e2m1": "__hip_fp4_e2m1",
    "float4_e2m1fn_x2": "__hip_fp4x2_e2m1",
}
+
+
[email protected]_cache(maxsize=None)
+def _determine_backend_once() -> Literal["cpu", "cuda", "rocm"]:
+    try:
+        import torch  # noqa: PLC0415
+
+        if torch.cuda.is_available():
+            if torch.version.cuda is not None:
+                return "cuda"
+            elif torch.version.hip is not None:
+                return "rocm"
+    except ImportError:
+        pass
+    return "cpu"
+
+
def to_cpp_dtype(dtype_str: str | Any) -> str:
    """Convert a dtype to its corresponding C++ dtype string.

    Parameters
    ----------
    dtype_str : `str` or `torch.dtype`
        The dtype string or object to convert.

    Returns
    -------
    str
        The corresponding C++ dtype string.

    """
    # Normalize to a bare dtype name: stringify torch.dtype objects and
    # strip the "torch." prefix their repr carries.
    name = dtype_str if isinstance(dtype_str, str) else str(dtype_str)
    if name.startswith("torch."):
        name = name[len("torch."):]
    # Portable C++ types do not depend on the GPU backend, so check them
    # first and avoid triggering backend detection (which imports torch).
    cpu_type = CPU_DTYPE_MAP.get(name)
    if cpu_type is not None:
        return cpu_type
    backend = _determine_backend_once()
    vendor_maps = {"cuda": CUDA_DTYPE_MAP, "rocm": ROCM_DTYPE_MAP}
    vendor_map = vendor_maps.get(backend)
    if vendor_map is not None:
        vendor_type = vendor_map.get(name)
        if vendor_type is not None:
            return vendor_type
    raise ValueError(f"Unsupported dtype string: {name} for {backend = }")

Reply via email to