This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a87749  [CYTHON] Improve fallback and dtype convert behavior (#241)
5a87749 is described below

commit 5a8774943a668f4344b0afbffa28ec5f826c8b15
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 7 22:22:15 2025 -0500

    [CYTHON] Improve fallback and dtype convert behavior (#241)
    
    This PR improves the torch fallback behavior and fixes the dtype
    conversion to ensure that tvm_ffi.convert always returns the right
    tvm_ffi.dtype when a dtype is passed in.
---
 addons/torch_c_dlpack_ext/README.md        | 11 ++++--
 docs/index.rst                             |  7 ++++
 python/tvm_ffi/_convert.py                 | 19 ++++++++-
 python/tvm_ffi/_dtype.py                   |  4 +-
 python/tvm_ffi/_optional_torch_c_dlpack.py |  5 ++-
 python/tvm_ffi/core.pyi                    |  4 +-
 python/tvm_ffi/cython/dtype.pxi            | 63 ++++++++++++++++--------------
 python/tvm_ffi/cython/function.pxi         | 19 +++++----
 tests/python/test_dtype.py                 |  2 +
 tests/python/test_function.py              |  2 +
 10 files changed, 87 insertions(+), 49 deletions(-)
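
As a hedged illustration of the intended behavior (mirrors the new assertions in
tests/python/test_dtype.py; assumes numpy is installed):

    import numpy as np
    import tvm_ffi

    # convert() on a dtype-like value now yields a tvm_ffi.dtype rather than an
    # opaque object, so the dtype properties are available on the result
    converted = tvm_ffi.convert(np.dtype("float32"))
    assert isinstance(converted, tvm_ffi.dtype)
    assert (converted.type_code, converted.bits, converted.lanes) == (2, 32, 1)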

diff --git a/addons/torch_c_dlpack_ext/README.md b/addons/torch_c_dlpack_ext/README.md
index 5e3b8d8..a3ac58c 100644
--- a/addons/torch_c_dlpack_ext/README.md
+++ b/addons/torch_c_dlpack_ext/README.md
@@ -16,10 +16,13 @@
 <!--- under the License. -->
 # Torch C DLPack Extension
 
-This folder contains the source for the `torch-c-dlpack-ext` package, which provides an Ahead-Of-Time (AOT) compiled module to support faster DLPack conversion.
-This module will likely become unnecessary once PyTorch releases include DLPack v1.2 support.
-By default, `tvm-ffi` will JIT-compile a version of this functionality during loading.
-Installing this wheel allows users to avoid this JIT compilation overhead.
+This folder contains the source for the `torch-c-dlpack-ext` package, which provides an
+Ahead-Of-Time (AOT) compiled module to support faster DLPack conversion with DLPack v1.2.
+By default, `tvm-ffi` will JIT-compile a version of this functionality during loading,
+and fall back to a safe path if the JIT compilation fails.
+Installing this wheel allows users to avoid the JIT compilation overhead and
+also covers environments that do not have a compiler toolchain available
+to run the JIT compilation.
 
 ```bash
 pip install torch-c-dlpack-ext
diff --git a/docs/index.rst b/docs/index.rst
index fd859a7..5b137d5 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -31,6 +31,13 @@ To install via pip, run:
 
    pip install apache-tvm-ffi
 
+We also recommend installing the optional package below for improved
+torch tensor conversion performance.
+
+.. code-block:: bash
+
+   pip install torch-c-dlpack-ext
+
 
 Table of Contents
 -----------------
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 43a564c..d0a44e3 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -18,11 +18,12 @@
 
 from __future__ import annotations
 
+import ctypes
 from numbers import Number
 from types import ModuleType
 from typing import Any
 
-from . import container, core
+from . import _dtype, container, core
 
 torch: ModuleType | None = None
 try:
@@ -90,7 +91,7 @@ def convert(value: Any) -> Any:  # noqa: PLR0911,PLR0912
     only used in internal or testing scenarios.
 
     """
-    if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)):
+    if isinstance(value, (core.Object, core.PyNativeObject, bool, Number, ctypes.c_void_p)):
         return value
     elif isinstance(value, (tuple, list)):
         return container.Array(value)
@@ -112,8 +113,22 @@ def convert(value: Any) -> Any:  # noqa: PLR0911,PLR0912
         return core._convert_torch_dtype_to_ffi_dtype(value)
     elif numpy is not None and isinstance(value, numpy.dtype):
         return core._convert_numpy_dtype_to_ffi_dtype(value)
+    elif hasattr(value, "__dlpack_data_type__"):
+        cdtype = core._create_cdtype_from_tuple(core.DataType, *value.__dlpack_data_type__())
+        dtype = str.__new__(_dtype.dtype, str(cdtype))
+        dtype._tvm_ffi_dtype = cdtype
+        return dtype
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
+    elif hasattr(value, "__tvm_ffi_object__"):
+        return value.__tvm_ffi_object__()
+    # keep the remaining protocol values as they are; the ffi function can handle them
+    elif hasattr(value, "__cuda_stream__"):
+        return value
+    elif hasattr(value, "__tvm_ffi_opaque_ptr__"):
+        return value
+    elif hasattr(value, "__dlpack_device__"):
+        return value
     else:
         # in this case, it is an opaque python object
         return core._convert_to_opaque_object(value)
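
For illustration, a hedged sketch of the new __dlpack_data_type__ hook handled in
convert() above (the class name below is illustrative; the test suite exercises a
similar DLPackDataTypeProtocol helper):

    import tvm_ffi

    class MyDTypeLike:
        def __dlpack_data_type__(self):
            # DLPack (type_code, bits, lanes) triple; 2 == kDLFloat
            return (2, 32, 1)

    # convert() routes through core._create_cdtype_from_tuple and returns a tvm_ffi.dtype
    converted = tvm_ffi.convert(MyDTypeLike())
    assert isinstance(converted, tvm_ffi.dtype)
    assert str(converted) == "float32"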
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index ebc2585..7fa78d5 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -129,7 +129,7 @@ class dtype(str):
             Create vector dtypes from a scalar base.
 
         """
-        cdtype = core._create_dtype_from_tuple(
+        cdtype = core._create_cdtype_from_tuple(
             core.DataType,
             dltype_data_type[0],
             dltype_data_type[1],
@@ -172,7 +172,7 @@ class dtype(str):
             Construct from a DLPack ``(code, bits, lanes)`` triple.
 
         """
-        cdtype = core._create_dtype_from_tuple(
+        cdtype = core._create_cdtype_from_tuple(
             core.DataType,
             self._tvm_ffi_dtype.type_code,
             self._tvm_ffi_dtype.bits,
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 7673c03..29c1db3 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -102,9 +102,10 @@ def load_torch_c_dlpack_extension() -> Any:
         return lib
     except ImportError:
         pass
-    except Exception as e:
+    except Exception:
         warnings.warn(
-            f"Failed to load torch c dlpack extension, EnvTensorAllocator will 
not be enabled:\n  {e}"
+            "Failed to JIT torch c dlpack extension, EnvTensorAllocator will 
not be enabled.\n"
+            "You may try AOT-module via `pip install torch-c-dlpack-ext`"
         )
     return None
 
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 853af72..d5d43b7 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -103,7 +103,9 @@ class DataType:
 def _set_class_dtype(cls: type) -> None: ...
 def _convert_torch_dtype_to_ffi_dtype(torch_dtype: Any) -> DataType: ...
 def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype: Any) -> DataType: ...
-def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes: int) -> DataType: ...
+def _create_cdtype_from_tuple(
+    cls: type[DataType], code: int, bits: int, lanes: int
+) -> DataType: ...
 
 class DLDeviceType(IntEnum):
     kDLCPU = 1
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 15f9418..3f3e460 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -51,12 +51,12 @@ def _set_class_dtype(cls):
     _CLASS_DTYPE = cls
 
 
-def _create_dtype_from_tuple(cls, code, bits, lanes):
+def _create_cdtype_from_tuple(cls, code, bits, lanes):
     cdef DLDataType cdtype
     cdtype.code = code
     cdtype.bits = bits
     cdtype.lanes = lanes
-    ret = cls.__new__(cls, str(cdtype))
+    ret = cls.__new__(cls)
     (<DataType>ret).cdtype = cdtype
     return ret
 
@@ -87,7 +87,7 @@ cdef class DataType:
 
     def __reduce__(self) -> Any:
         cls = type(self)
-        return (_create_dtype_from_tuple,
+        return (_create_cdtype_from_tuple,
                 (cls, self.cdtype.code, self.cdtype.bits, self.cdtype.lanes))
 
     def __eq__(self, other: object) -> bool:
@@ -102,6 +102,9 @@ cdef class DataType:
     def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
+    def __hash__(self) -> int:
+        return hash((self.cdtype.code, self.cdtype.bits, self.cdtype.lanes))
+
     @property
     def type_code(self) -> int:
         """Integer DLDataTypeCode of the scalar base type."""
@@ -151,20 +154,24 @@ cdef class DataType:
         return res
 
 
-cdef inline object make_ret_dtype(TVMFFIAny result):
+cdef inline object make_dtype_from_dl_data_type(DLDataType dl_data_type):
     cdtype = DataType.__new__(DataType)
-    (<DataType>cdtype).cdtype = result.v_dtype
+    (<DataType>cdtype).cdtype = dl_data_type
     val = str.__new__(_CLASS_DTYPE, cdtype.__str__())
     val._tvm_ffi_dtype = cdtype
     return val
 
 
-cdef TORCH_DTYPE_TO_DTYPE = {}
-cdef NUMPY_DTYPE_TO_DTYPE = {}
-cdef MLDTYPES_DTYPE_TO_DTYPE = {}
+cdef inline object make_ret_dtype(TVMFFIAny result):
+    return make_dtype_from_dl_data_type(result.v_dtype)
+
+
+cdef TORCH_DTYPE_TO_DL_DATA_TYPE = {}
+cdef NUMPY_DTYPE_TO_DL_DATA_TYPE = {}
+cdef MLDTYPES_DTYPE_TO_DL_DATA_TYPE = {}
 
 if torch is not None:
-    TORCH_DTYPE_TO_DTYPE = {
+    TORCH_DTYPE_TO_DL_DATA_TYPE = {
         torch.int8: DLDataType(0, 8, 1),
         torch.short: DLDataType(0, 16, 1),
         torch.int16: DLDataType(0, 16, 1),
@@ -190,21 +197,19 @@ if torch is not None:
         torch.float8_e5m2fnuz: DLDataType(13, 8, 1),
     }
     if hasattr(torch, "float8_e8m0fnu"):
-        TORCH_DTYPE_TO_DTYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
+        TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
     if hasattr(torch, "float4_e2m1fn_x2"):
-        TORCH_DTYPE_TO_DTYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
+        TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
 
     def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
-        cdef DLDataType cdtype = TORCH_DTYPE_TO_DTYPE[torch_dtype]
-        ret = DataType.__new__(DataType, str(cdtype))
-        (<DataType>ret).cdtype = cdtype
-        return ret
+        cdef DLDataType dl_data_type = TORCH_DTYPE_TO_DL_DATA_TYPE[torch_dtype]
+        return make_dtype_from_dl_data_type(dl_data_type)
 else:
     def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
         raise ValueError("torch not found")
 
 if ml_dtypes is not None:
-    MLDTYPES_DTYPE_TO_DTYPE = {
+    MLDTYPES_DTYPE_TO_DL_DATA_TYPE = {
         numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
         numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
         numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
@@ -216,19 +221,19 @@ if ml_dtypes is not None:
     }
 
     if hasattr(ml_dtypes, "int2"):  # ml_dtypes >= 0.5.0
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
 
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
-        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
+        MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
 
 
 if numpy is not None:
-    NUMPY_DTYPE_TO_DTYPE = {
+    NUMPY_DTYPE_TO_DL_DATA_TYPE = {
         numpy.dtype(numpy.int8): DLDataType(0, 8, 1),
         numpy.dtype(numpy.int16): DLDataType(0, 16, 1),
         numpy.dtype(numpy.int32): DLDataType(0, 32, 1),
@@ -240,14 +245,12 @@ if numpy is not None:
         numpy.dtype(numpy.float16): DLDataType(2, 16, 1),
         numpy.dtype(numpy.float32): DLDataType(2, 32, 1),
         numpy.dtype(numpy.float64): DLDataType(2, 64, 1),
-        **MLDTYPES_DTYPE_TO_DTYPE,
+        **MLDTYPES_DTYPE_TO_DL_DATA_TYPE,
     }
 
     def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype):
-        cdef DLDataType cdtype = NUMPY_DTYPE_TO_DTYPE[numpy_dtype]
-        ret = DataType.__new__(DataType, str(cdtype))
-        (<DataType>ret).cdtype = cdtype
-        return ret
+        cdef DLDataType cdtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[numpy_dtype]
+        return make_dtype_from_dl_data_type(cdtype)
 else:
     def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
         raise ValueError("numpy not found")
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index acfe3e2..3ebe862 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -188,7 +188,7 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
     return 0
 
 
-cdef int TorchDLPackToPyObjectFallback_(
+cdef int TorchManagedTensorToPyObjectNoSyncFallback_(
     DLManagedTensorVersioned* dltensor, void** py_obj_out
 ) except -1:
     # a bit convoluted but ok as a fallback
@@ -211,7 +211,9 @@ cdef inline const DLPackExchangeAPI* GetTorchFallbackExchangeAPI() noexcept:
     _torch_fallback_exchange_api.header.prev_api = NULL
     _torch_fallback_exchange_api.managed_tensor_allocator = NULL
     _torch_fallback_exchange_api.managed_tensor_from_py_object_no_sync = NULL
-    _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = TorchDLPackToPyObjectFallback_
+    _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = (
+        TorchManagedTensorToPyObjectNoSyncFallback_
+    )
     _torch_fallback_exchange_api.dltensor_from_py_object_no_sync = NULL
     _torch_fallback_exchange_api.current_work_stream = NULL
 
@@ -228,6 +230,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
     """Current setter for torch.Tensor, go through python and not as fast as c 
exporter"""
     # TODO(tqchen): remove this once torch always support fast DLPack importer
     cdef object arg = <object>py_arg
+    cdef long long temp_ptr
     is_cuda = arg.is_cuda
     arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg))
     out.type_index = kTVMFFITensor
@@ -235,7 +238,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
     temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
     ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
     # record the stream and device for torch context
-    if is_cuda and ctx.device_type != -1:
+    if is_cuda and ctx.device_type == -1:
         ctx.device_type = temp_dltensor.device.device_type
         ctx.device_id = temp_dltensor.device.device_id
         # This is an API that dynamo and other uses to get the raw stream from torch
@@ -587,10 +590,10 @@ cdef int TVMFFIPyArgSetterDTypeFromTorch_(
 ) except -1:
     """Setter for torch dtype"""
     cdef py_obj = <object>py_arg
-    if py_obj not in TORCH_DTYPE_TO_DTYPE:
+    if py_obj not in TORCH_DTYPE_TO_DL_DATA_TYPE:
         raise ValueError("Unsupported torch dtype: ", py_obj)
     out.type_index = kTVMFFIDataType
-    out.v_dtype = TORCH_DTYPE_TO_DTYPE[py_obj]
+    out.v_dtype = TORCH_DTYPE_TO_DL_DATA_TYPE[py_obj]
     return 0
 
 cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
@@ -599,10 +602,10 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
 ) except -1:
     """Setter for torch dtype"""
     cdef py_obj = <object>py_arg
-    if py_obj not in NUMPY_DTYPE_TO_DTYPE:
+    if py_obj not in NUMPY_DTYPE_TO_DL_DATA_TYPE:
         raise ValueError("Unsupported numpy or ml_dtypes dtype: ", py_obj)
     out.type_index = kTVMFFIDataType
-    out.v_dtype = NUMPY_DTYPE_TO_DTYPE[py_obj]
+    out.v_dtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[py_obj]
     return 0
 
 cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
@@ -667,7 +670,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
         # as a member variable
         out.func = TVMFFIPyArgSetterFFIObjectCompatible_
         return 0
-    if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
+    if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
         # Check for DLPackExchangeAPI struct (new approach)
         # This is checked on the CLASS, not the instance
         if hasattr(arg_class, "__c_dlpack_exchange_api__"):
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 60d53ef..cae5f84 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -88,10 +88,12 @@ _fecho = tvm_ffi.get_global_func("testing.echo")
 
 def _check_dtype(dtype: Any, code: int, bits: int, lanes: int) -> None:
     echo_dtype = _fecho(dtype)
+    assert isinstance(echo_dtype, tvm_ffi.dtype)
     assert echo_dtype.type_code == code
     assert echo_dtype.bits == bits
     assert echo_dtype.lanes == lanes
     converted_dtype = tvm_ffi.convert(dtype)
+    assert isinstance(converted_dtype, tvm_ffi.dtype)
     assert converted_dtype.type_code == code
     assert converted_dtype.bits == bits
     assert converted_dtype.lanes == lanes
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index deef0f5..a1e68b0 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -332,6 +332,8 @@ def test_function_with_dlpack_data_type_protocol() -> None:
     x = DLPackDataTypeProtocol((dtype.type_code, dtype.bits, dtype.lanes))
     y = fecho(x)
     assert y == dtype
+    converted_y = tvm_ffi.convert(x)
+    assert converted_y == dtype
 
 
 def test_function_with_dlpack_device_protocol() -> None:
