This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c51e519 [Feature] support C++ dtype_trait and Python-side mapping to
C++ dtype (#374)
c51e519 is described below
commit c51e519b2253c2c8754bebaf2f9af0434d89e1fc
Author: DarkSharpness <[email protected]>
AuthorDate: Fri Jan 2 02:43:05 2026 +0800
[Feature] support C++ dtype_trait and Python-side mapping to C++ dtype
(#374)
related issue #365
references:
1.
https://github.com/ROCm/clr/tree/amd-staging/hipamd/include/hip/amd_detail
2.
https://github.com/pytorch/pytorch/blob/f7f91ec63a6575443b2b06ded791ac7fd9a7f66d/aten/src/ATen/DLConvertor.cpp
3.
https://github.com/flashinfer-ai/flashinfer/blob/main/csrc/tvm_ffi_utils.h
---
include/tvm/ffi/extra/dtype.h | 206 +++++++++++++++++++++++++++++++++++++++++
python/tvm_ffi/cpp/__init__.py | 2 +
python/tvm_ffi/cpp/dtype.py | 104 +++++++++++++++++++++
3 files changed, 312 insertions(+)
diff --git a/include/tvm/ffi/extra/dtype.h b/include/tvm/ffi/extra/dtype.h
new file mode 100644
index 0000000..4187d0f
--- /dev/null
+++ b/include/tvm/ffi/extra/dtype.h
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/dtype.h
+ * \brief Type traits to map C++ types to DLPack dtypes.
+ */
+#ifndef TVM_FFI_EXTRA_DTYPE_H_
+#define TVM_FFI_EXTRA_DTYPE_H_
+
+#include <dlpack/dlpack.h>
+
+#include <type_traits>
+
// Forward declarations of the vendor half/bfloat16/fp8/fp4 types so that
// this header can be included without the CUDA/HIP toolkit headers; only
// template specializations on these types appear below, so incomplete
// types suffice. The full definitions come from the toolkit headers
// (NOTE(review): presumably cuda_fp16.h/cuda_fp8.h and the HIP
// amd_detail headers -- confirm against the toolkits actually supported).

// Common for both CUDA and HIP
struct __half;

// CUDA
struct __nv_fp8_e4m3;
struct __nv_bfloat16;
struct __nv_fp8_e5m2;
struct __nv_fp8_e8m0;
struct __nv_fp4_e2m1;
struct __nv_fp4x2_e2m1;

// HIP
struct __hip_bfloat16;
struct hip_bfloat16;  // NOTE: declared as a struct (not an alias) to match the HIP headers.
struct __hip_fp8_e4m3;
struct __hip_fp8_e4m3_fnuz;
struct __hip_fp8_e5m2;
struct __hip_fp8_e5m2_fnuz;
struct __hip_fp4_e2m1;
struct __hip_fp4x2_e2m1;
+
+namespace tvm_ffi {
+
/// \cond Doxygen_Suppress

// Primary template. Intentionally empty: accessing dtype_trait<T>::value
// for a type with no specialization below is a compile-time error, which
// is how unsupported element types are rejected.
template <typename T>
struct dtype_trait {};

namespace details::dtypes {

// Shared implementation for the builtin integer types: signedness selects
// the DLPack type code and the bit width is derived from sizeof(T).
template <typename T>
struct integer_trait {
  static constexpr DLDataType value = {
      /* code = */ std::is_signed_v<T> ? kDLInt : kDLUInt,
      /* bits = */ static_cast<uint8_t>(sizeof(T) * 8),
      /* lanes = */ 1,
  };
};

// Shared implementation for the builtin floating-point types
// (float and double); bit width derived from sizeof(T).
template <typename T>
struct float_trait {
  static constexpr DLDataType value = {
      /* code = */ kDLFloat,
      /* bits = */ static_cast<uint8_t>(sizeof(T) * 8),
      /* lanes = */ 1,
  };
};

}  // namespace details::dtypes
+
// Specializations for the builtin integer and floating-point types.
// The <cstdint> fixed-width aliases (int8_t, uint64_t, ...) are aliases of
// these types and therefore resolve to the same specializations.
// NOTE(review): plain `char` is a distinct type from both `signed char`
// and `unsigned char` and has no specialization here -- confirm whether
// that omission is intentional.

template <>
struct dtype_trait<signed char> : details::dtypes::integer_trait<signed char> {};

template <>
struct dtype_trait<unsigned char> : details::dtypes::integer_trait<unsigned char> {};

template <>
struct dtype_trait<signed short> : details::dtypes::integer_trait<signed short> {};

template <>
struct dtype_trait<unsigned short> : details::dtypes::integer_trait<unsigned short> {};

template <>
struct dtype_trait<signed int> : details::dtypes::integer_trait<signed int> {};

template <>
struct dtype_trait<unsigned int> : details::dtypes::integer_trait<unsigned int> {};

template <>
struct dtype_trait<signed long> : details::dtypes::integer_trait<signed long> {};

template <>
struct dtype_trait<unsigned long> : details::dtypes::integer_trait<unsigned long> {};

template <>
struct dtype_trait<signed long long> : details::dtypes::integer_trait<signed long long> {};

template <>
struct dtype_trait<unsigned long long> : details::dtypes::integer_trait<unsigned long long> {};

template <>
struct dtype_trait<float> : details::dtypes::float_trait<float> {};

template <>
struct dtype_trait<double> : details::dtypes::float_trait<double> {};

// Specialization for bool

// DLPack represents bool as an 8-bit kDLBool lane, independent of
// sizeof(bool) on the host compiler.
template <>
struct dtype_trait<bool> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBool, 8, 1};
};
+
// Specializations for CUDA

// 16-bit IEEE half precision; the __half type name is shared with HIP.
template <>
struct dtype_trait<__half> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat, 16, 1};
};

template <>
struct dtype_trait<__nv_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

template <>
struct dtype_trait<__nv_fp8_e4m3> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1};
};

template <>
struct dtype_trait<__nv_fp8_e5m2> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1};
};

template <>
struct dtype_trait<__nv_fp8_e8m0> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e8m0fnu, 8, 1};
};

template <>
struct dtype_trait<__nv_fp4_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1};
};

// The x2 type carries two 4-bit values (lanes == 2, 4 bits each);
// presumably packed into a single byte -- confirm against cuda_fp4.h.
template <>
struct dtype_trait<__nv_fp4x2_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2};
};
+
// Specializations for HIP

template <>
struct dtype_trait<__hip_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

// hip_bfloat16 is a second, distinct bfloat16 type in the HIP headers;
// it maps to the same DLPack dtype as __hip_bfloat16.
template <>
struct dtype_trait<hip_bfloat16> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLBfloat, 16, 1};
};

template <>
struct dtype_trait<__hip_fp8_e4m3> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fn, 8, 1};
};

// fnuz ("finite, no unsigned zero") fp8 variants exist only on HIP;
// the CUDA section above has no counterpart.
template <>
struct dtype_trait<__hip_fp8_e4m3_fnuz> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e4m3fnuz, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e5m2> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2, 8, 1};
};

template <>
struct dtype_trait<__hip_fp8_e5m2_fnuz> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat8_e5m2fnuz, 8, 1};
};

template <>
struct dtype_trait<__hip_fp4_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 1};
};

// Two 4-bit lanes, mirroring the CUDA __nv_fp4x2_e2m1 specialization.
template <>
struct dtype_trait<__hip_fp4x2_e2m1> {
  static constexpr DLDataType value = {DLDataTypeCode::kDLFloat4_e2m1fn, 4, 2};
};
+
+/// \endcond
+
+} // namespace tvm_ffi
+
+#endif // TVM_FFI_EXTRA_DTYPE_H_
diff --git a/python/tvm_ffi/cpp/__init__.py b/python/tvm_ffi/cpp/__init__.py
index e3deb12..f2d4ce2 100644
--- a/python/tvm_ffi/cpp/__init__.py
+++ b/python/tvm_ffi/cpp/__init__.py
@@ -16,6 +16,7 @@
# under the License.
"""C++ integration helpers for building and loading inline modules."""
+from .dtype import to_cpp_dtype
from .extension import build, build_inline, load, load_inline
__all__ = [
@@ -23,4 +24,5 @@ __all__ = [
"build_inline",
"load",
"load_inline",
+ "to_cpp_dtype",
]
diff --git a/python/tvm_ffi/cpp/dtype.py b/python/tvm_ffi/cpp/dtype.py
new file mode 100644
index 0000000..4a782b0
--- /dev/null
+++ b/python/tvm_ffi/cpp/dtype.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for C++ dtype conversion."""
+
+from __future__ import annotations
+
+import functools
+from typing import Any, Literal
+
# Mappings from dtype names (as produced by ``str(torch.dtype)`` with the
# ``torch.`` prefix stripped) to C++ type names.
#
# ``CPU_DTYPE_MAP`` holds the backend-independent builtin types; the CUDA
# and ROCm maps hold the GPU-only half/bfloat16/fp8/fp4 types, whose C++
# spellings differ between the two toolkits.

CPU_DTYPE_MAP = {
    # Fixed-width integers: "intN"/"uintN" -> the <cstdint> typedef names.
    **{f"int{bits}": f"int{bits}_t" for bits in (8, 16, 32, 64)},
    **{f"uint{bits}": f"uint{bits}_t" for bits in (8, 16, 32, 64)},
    "float32": "float",
    "float64": "double",
    "bool": "bool",
}

CUDA_DTYPE_MAP = {
    "float16": "__half",
    "bfloat16": "__nv_bfloat16",
    "float8_e4m3fn": "__nv_fp8_e4m3",
    "float8_e5m2": "__nv_fp8_e5m2",
    "float8_e8m0fnu": "__nv_fp8_e8m0",
    "float4_e2m1": "__nv_fp4_e2m1",
    "float4_e2m1fn_x2": "__nv_fp4x2_e2m1",
    # NOTE: the fnuz fp8 variants are intentionally unmapped here --
    # presumably CUDA has no dedicated __nv_* fnuz types (the ROCm map
    # below does map them). Confirm before adding.
}

ROCM_DTYPE_MAP = {
    "float16": "__half",
    "bfloat16": "__hip_bfloat16",
    "float8_e4m3fn": "__hip_fp8_e4m3",
    "float8_e4m3fnuz": "__hip_fp8_e4m3_fnuz",
    "float8_e5m2": "__hip_fp8_e5m2",
    "float8_e5m2fnuz": "__hip_fp8_e5m2_fnuz",
    "float4_e2m1": "__hip_fp4_e2m1",
    "float4_e2m1fn_x2": "__hip_fp4x2_e2m1",
}
+
+
[email protected]_cache(maxsize=None)
+def _determine_backend_once() -> Literal["cpu", "cuda", "rocm"]:
+ try:
+ import torch # noqa: PLC0415
+
+ if torch.cuda.is_available():
+ if torch.version.cuda is not None:
+ return "cuda"
+ elif torch.version.hip is not None:
+ return "rocm"
+ except ImportError:
+ pass
+ return "cpu"
+
+
def to_cpp_dtype(dtype_str: str | Any) -> str:
    """Convert a dtype to its corresponding C++ dtype string.

    Parameters
    ----------
    dtype_str : `str` or `torch.dtype`
        The dtype string or object to convert. Non-string objects are
        converted via ``str()``, and a leading ``"torch."`` prefix (as
        produced by e.g. ``str(torch.float32)``) is stripped.

    Returns
    -------
    str
        The corresponding C++ dtype string.

    Raises
    ------
    ValueError
        If the dtype has no C++ mapping for the detected backend.

    """
    if not isinstance(dtype_str, str):
        dtype_str = str(dtype_str)
    # str(torch.float32) == "torch.float32"; normalize to the bare name.
    # (Derive the slice length from the prefix instead of a magic 6.)
    prefix = "torch."
    if dtype_str.startswith(prefix):
        dtype_str = dtype_str[len(prefix):]
    # Backend-independent builtin types are always available.
    cpp_str = CPU_DTYPE_MAP.get(dtype_str)
    if cpp_str is not None:
        return cpp_str
    # GPU-only types (half/bfloat16/fp8/fp4) depend on the active toolkit.
    backend = _determine_backend_once()
    if backend in ("cuda", "rocm"):
        dtype_map = CUDA_DTYPE_MAP if backend == "cuda" else ROCM_DTYPE_MAP
        cpp_str = dtype_map.get(dtype_str)
        if cpp_str is not None:
            return cpp_str
    raise ValueError(f"Unsupported dtype string: {dtype_str} for {backend = }")