This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 5a87749 [CYTHON] Improve fallback and dtype convert behavior (#241)
5a87749 is described below
commit 5a8774943a668f4344b0afbffa28ec5f826c8b15
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 7 22:22:15 2025 -0500
[CYTHON] Improve fallback and dtype convert behavior (#241)
This PR improves the torch fallback behavior and fixes the dtype
conversion to ensure tvm_ffi.convert always returns the right
tvm_ffi.dtype when a dtype is passed in.
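A minimal sketch of the intended behavior (assuming torch and numpy are
installed):

    import numpy
    import torch
    import tvm_ffi

    # Framework dtypes now convert to tvm_ffi.dtype instances.
    d = tvm_ffi.convert(torch.int8)
    assert isinstance(d, tvm_ffi.dtype)

    d = tvm_ffi.convert(numpy.dtype("float32"))
    assert isinstance(d, tvm_ffi.dtype)
    assert (d.type_code, d.bits, d.lanes) == (2, 32, 1)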
---
addons/torch_c_dlpack_ext/README.md | 11 ++++--
docs/index.rst | 7 ++++
python/tvm_ffi/_convert.py | 19 ++++++++-
python/tvm_ffi/_dtype.py | 4 +-
python/tvm_ffi/_optional_torch_c_dlpack.py | 5 ++-
python/tvm_ffi/core.pyi | 4 +-
python/tvm_ffi/cython/dtype.pxi | 63 ++++++++++++++++--------------
python/tvm_ffi/cython/function.pxi | 19 +++++----
tests/python/test_dtype.py | 2 +
tests/python/test_function.py | 2 +
10 files changed, 87 insertions(+), 49 deletions(-)
diff --git a/addons/torch_c_dlpack_ext/README.md b/addons/torch_c_dlpack_ext/README.md
index 5e3b8d8..a3ac58c 100644
--- a/addons/torch_c_dlpack_ext/README.md
+++ b/addons/torch_c_dlpack_ext/README.md
@@ -16,10 +16,13 @@
<!--- under the License. -->
# Torch C DLPack Extension
-This folder contains the source for the `torch-c-dlpack-ext` package, which provides an Ahead-Of-Time (AOT) compiled module to support faster DLPack conversion.
-This module will likely become unnecessary once PyTorch releases include DLPack v1.2 support.
-By default, `tvm-ffi` will JIT-compile a version of this functionality during loading.
-Installing this wheel allows users to avoid this JIT compilation overhead.
+This folder contains the source for the `torch-c-dlpack-ext` package, which provides an
+Ahead-Of-Time (AOT) compiled module to support faster DLPack conversion in DLPack v1.2.
+By default, `tvm-ffi` will JIT-compile a version of this functionality during loading,
+and use a safe-fallback if JIT-compilation fails.
+Installing this wheel allows users to avoid this JIT compilation overhead and
+also avoid the cases where the user environment does not necessarily have a compiler toolchain
+to run JIT-compilation.
```bash
pip install torch-c-dlpack-ext
diff --git a/docs/index.rst b/docs/index.rst
index fd859a7..5b137d5 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -31,6 +31,13 @@ To install via pip, run:
pip install apache-tvm-ffi
+We also recommend installing the optional package below for improved
+torch tensor conversion performance.
+
+.. code-block:: bash
+
+ pip install torch-c-dlpack-ext
+
Table of Contents
-----------------
diff --git a/python/tvm_ffi/_convert.py b/python/tvm_ffi/_convert.py
index 43a564c..d0a44e3 100644
--- a/python/tvm_ffi/_convert.py
+++ b/python/tvm_ffi/_convert.py
@@ -18,11 +18,12 @@
from __future__ import annotations
+import ctypes
from numbers import Number
from types import ModuleType
from typing import Any
-from . import container, core
+from . import _dtype, container, core
torch: ModuleType | None = None
try:
@@ -90,7 +91,7 @@ def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
only used in internal or testing scenarios.
"""
- if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)):
+ if isinstance(value, (core.Object, core.PyNativeObject, bool, Number, ctypes.c_void_p)):
return value
elif isinstance(value, (tuple, list)):
return container.Array(value)
@@ -112,8 +113,22 @@ def convert(value: Any) -> Any: # noqa: PLR0911,PLR0912
return core._convert_torch_dtype_to_ffi_dtype(value)
elif numpy is not None and isinstance(value, numpy.dtype):
return core._convert_numpy_dtype_to_ffi_dtype(value)
+ elif hasattr(value, "__dlpack_data_type__"):
+ cdtype = core._create_cdtype_from_tuple(core.DataType, *value.__dlpack_data_type__())
+ dtype = str.__new__(_dtype.dtype, str(cdtype))
+ dtype._tvm_ffi_dtype = cdtype
+ return dtype
elif isinstance(value, Exception):
return core._convert_to_ffi_error(value)
+ elif hasattr(value, "__tvm_ffi_object__"):
+ return value.__tvm_ffi_object__()
+ # keep rest protocol values as it is as they can be handled by ffi function
+ elif hasattr(value, "__cuda_stream__"):
+ return value
+ elif hasattr(value, "__tvm_ffi_opaque_ptr__"):
+ return value
+ elif hasattr(value, "__dlpack_device__"):
+ return value
else:
# in this case, it is an opaque python object
return core._convert_to_opaque_object(value)
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index ebc2585..7fa78d5 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -129,7 +129,7 @@ class dtype(str):
Create vector dtypes from a scalar base.
"""
- cdtype = core._create_dtype_from_tuple(
+ cdtype = core._create_cdtype_from_tuple(
core.DataType,
dltype_data_type[0],
dltype_data_type[1],
@@ -172,7 +172,7 @@ class dtype(str):
Construct from a DLPack ``(code, bits, lanes)`` triple.
"""
- cdtype = core._create_dtype_from_tuple(
+ cdtype = core._create_cdtype_from_tuple(
core.DataType,
self._tvm_ffi_dtype.type_code,
self._tvm_ffi_dtype.bits,
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 7673c03..29c1db3 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -102,9 +102,10 @@ def load_torch_c_dlpack_extension() -> Any:
return lib
except ImportError:
pass
- except Exception as e:
+ except Exception:
warnings.warn(
- f"Failed to load torch c dlpack extension, EnvTensorAllocator will
not be enabled:\n {e}"
+ "Failed to JIT torch c dlpack extension, EnvTensorAllocator will
not be enabled.\n"
+ "You may try AOT-module via `pip install torch-c-dlpack-ext`"
)
return None
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 853af72..d5d43b7 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -103,7 +103,9 @@ class DataType:
def _set_class_dtype(cls: type) -> None: ...
def _convert_torch_dtype_to_ffi_dtype(torch_dtype: Any) -> DataType: ...
def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype: Any) -> DataType: ...
-def _create_dtype_from_tuple(cls: type[DataType], code: int, bits: int, lanes: int) -> DataType: ...
+def _create_cdtype_from_tuple(
+ cls: type[DataType], code: int, bits: int, lanes: int
+) -> DataType: ...
class DLDeviceType(IntEnum):
kDLCPU = 1
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 15f9418..3f3e460 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -51,12 +51,12 @@ def _set_class_dtype(cls):
_CLASS_DTYPE = cls
-def _create_dtype_from_tuple(cls, code, bits, lanes):
+def _create_cdtype_from_tuple(cls, code, bits, lanes):
cdef DLDataType cdtype
cdtype.code = code
cdtype.bits = bits
cdtype.lanes = lanes
- ret = cls.__new__(cls, str(cdtype))
+ ret = cls.__new__(cls)
(<DataType>ret).cdtype = cdtype
return ret
@@ -87,7 +87,7 @@ cdef class DataType:
def __reduce__(self) -> Any:
cls = type(self)
- return (_create_dtype_from_tuple,
+ return (_create_cdtype_from_tuple,
(cls, self.cdtype.code, self.cdtype.bits, self.cdtype.lanes))
def __eq__(self, other: object) -> bool:
@@ -102,6 +102,9 @@ cdef class DataType:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
+ def __hash__(self) -> int:
+ return hash((self.cdtype.code, self.cdtype.bits, self.cdtype.lanes))
+
@property
def type_code(self) -> int:
"""Integer DLDataTypeCode of the scalar base type."""
@@ -151,20 +154,24 @@ cdef class DataType:
return res
-cdef inline object make_ret_dtype(TVMFFIAny result):
+cdef inline object make_dtype_from_dl_data_type(DLDataType dl_data_type):
cdtype = DataType.__new__(DataType)
- (<DataType>cdtype).cdtype = result.v_dtype
+ (<DataType>cdtype).cdtype = dl_data_type
val = str.__new__(_CLASS_DTYPE, cdtype.__str__())
val._tvm_ffi_dtype = cdtype
return val
-cdef TORCH_DTYPE_TO_DTYPE = {}
-cdef NUMPY_DTYPE_TO_DTYPE = {}
-cdef MLDTYPES_DTYPE_TO_DTYPE = {}
+cdef inline object make_ret_dtype(TVMFFIAny result):
+ return make_dtype_from_dl_data_type(result.v_dtype)
+
+
+cdef TORCH_DTYPE_TO_DL_DATA_TYPE = {}
+cdef NUMPY_DTYPE_TO_DL_DATA_TYPE = {}
+cdef MLDTYPES_DTYPE_TO_DL_DATA_TYPE = {}
if torch is not None:
- TORCH_DTYPE_TO_DTYPE = {
+ TORCH_DTYPE_TO_DL_DATA_TYPE = {
torch.int8: DLDataType(0, 8, 1),
torch.short: DLDataType(0, 16, 1),
torch.int16: DLDataType(0, 16, 1),
@@ -190,21 +197,19 @@ if torch is not None:
torch.float8_e5m2fnuz: DLDataType(13, 8, 1),
}
if hasattr(torch, "float8_e8m0fnu"):
- TORCH_DTYPE_TO_DTYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
+ TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
if hasattr(torch, "float4_e2m1fn_x2"):
- TORCH_DTYPE_TO_DTYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
+ TORCH_DTYPE_TO_DL_DATA_TYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
- cdef DLDataType cdtype = TORCH_DTYPE_TO_DTYPE[torch_dtype]
- ret = DataType.__new__(DataType, str(cdtype))
- (<DataType>ret).cdtype = cdtype
- return ret
+ cdef DLDataType dl_data_type = TORCH_DTYPE_TO_DL_DATA_TYPE[torch_dtype]
+ return make_dtype_from_dl_data_type(dl_data_type)
else:
def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
raise ValueError("torch not found")
if ml_dtypes is not None:
- MLDTYPES_DTYPE_TO_DTYPE = {
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE = {
numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
@@ -216,19 +221,19 @@ if ml_dtypes is not None:
}
if hasattr(ml_dtypes, "int2"): # ml_dtypes >= 0.5.0
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
- MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
+ MLDTYPES_DTYPE_TO_DL_DATA_TYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
if numpy is not None:
- NUMPY_DTYPE_TO_DTYPE = {
+ NUMPY_DTYPE_TO_DL_DATA_TYPE = {
numpy.dtype(numpy.int8): DLDataType(0, 8, 1),
numpy.dtype(numpy.int16): DLDataType(0, 16, 1),
numpy.dtype(numpy.int32): DLDataType(0, 32, 1),
@@ -240,14 +245,12 @@ if numpy is not None:
numpy.dtype(numpy.float16): DLDataType(2, 16, 1),
numpy.dtype(numpy.float32): DLDataType(2, 32, 1),
numpy.dtype(numpy.float64): DLDataType(2, 64, 1),
- **MLDTYPES_DTYPE_TO_DTYPE,
+ **MLDTYPES_DTYPE_TO_DL_DATA_TYPE,
}
def _convert_numpy_dtype_to_ffi_dtype(numpy_dtype):
- cdef DLDataType cdtype = NUMPY_DTYPE_TO_DTYPE[numpy_dtype]
- ret = DataType.__new__(DataType, str(cdtype))
- (<DataType>ret).cdtype = cdtype
- return ret
+ cdef DLDataType cdtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[numpy_dtype]
+ return make_dtype_from_dl_data_type(cdtype)
else:
def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
raise ValueError("numpy not found")
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index acfe3e2..3ebe862 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -188,7 +188,7 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
return 0
-cdef int TorchDLPackToPyObjectFallback_(
+cdef int TorchManagedTensorToPyObjectNoSyncFallback_(
DLManagedTensorVersioned* dltensor, void** py_obj_out
) except -1:
# a bit convoluted but ok as a fallback
@@ -211,7 +211,9 @@ cdef inline const DLPackExchangeAPI* GetTorchFallbackExchangeAPI() noexcept:
_torch_fallback_exchange_api.header.prev_api = NULL
_torch_fallback_exchange_api.managed_tensor_allocator = NULL
_torch_fallback_exchange_api.managed_tensor_from_py_object_no_sync = NULL
- _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = TorchDLPackToPyObjectFallback_
+ _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = (
+ TorchManagedTensorToPyObjectNoSyncFallback_
+ )
_torch_fallback_exchange_api.dltensor_from_py_object_no_sync = NULL
_torch_fallback_exchange_api.current_work_stream = NULL
@@ -228,6 +230,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
"""Current setter for torch.Tensor, go through python and not as fast as c
exporter"""
# TODO(tqchen): remove this once torch always support fast DLPack importer
cdef object arg = <object>py_arg
+ cdef long long temp_ptr
is_cuda = arg.is_cuda
arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg))
out.type_index = kTVMFFITensor
@@ -235,7 +238,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
# record the stream and device for torch context
- if is_cuda and ctx.device_type != -1:
+ if is_cuda and ctx.device_type == -1:
ctx.device_type = temp_dltensor.device.device_type
ctx.device_id = temp_dltensor.device.device_id
# This is an API that dynamo and other uses to get the raw stream from torch
@@ -587,10 +590,10 @@ cdef int TVMFFIPyArgSetterDTypeFromTorch_(
) except -1:
"""Setter for torch dtype"""
cdef py_obj = <object>py_arg
- if py_obj not in TORCH_DTYPE_TO_DTYPE:
+ if py_obj not in TORCH_DTYPE_TO_DL_DATA_TYPE:
raise ValueError("Unsupported torch dtype: ", py_obj)
out.type_index = kTVMFFIDataType
- out.v_dtype = TORCH_DTYPE_TO_DTYPE[py_obj]
+ out.v_dtype = TORCH_DTYPE_TO_DL_DATA_TYPE[py_obj]
return 0
cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
@@ -599,10 +602,10 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
) except -1:
"""Setter for torch dtype"""
cdef py_obj = <object>py_arg
- if py_obj not in NUMPY_DTYPE_TO_DTYPE:
+ if py_obj not in NUMPY_DTYPE_TO_DL_DATA_TYPE:
raise ValueError("Unsupported numpy or ml_dtypes dtype: ", py_obj)
out.type_index = kTVMFFIDataType
- out.v_dtype = NUMPY_DTYPE_TO_DTYPE[py_obj]
+ out.v_dtype = NUMPY_DTYPE_TO_DL_DATA_TYPE[py_obj]
return 0
cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
@@ -667,7 +670,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
# as a member variable
out.func = TVMFFIPyArgSetterFFIObjectCompatible_
return 0
- if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
+ if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
# Check for DLPackExchangeAPI struct (new approach)
# This is checked on the CLASS, not the instance
if hasattr(arg_class, "__c_dlpack_exchange_api__"):
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 60d53ef..cae5f84 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -88,10 +88,12 @@ _fecho = tvm_ffi.get_global_func("testing.echo")
def _check_dtype(dtype: Any, code: int, bits: int, lanes: int) -> None:
echo_dtype = _fecho(dtype)
+ assert isinstance(echo_dtype, tvm_ffi.dtype)
assert echo_dtype.type_code == code
assert echo_dtype.bits == bits
assert echo_dtype.lanes == lanes
converted_dtype = tvm_ffi.convert(dtype)
+ assert isinstance(converted_dtype, tvm_ffi.dtype)
assert converted_dtype.type_code == code
assert converted_dtype.bits == bits
assert converted_dtype.lanes == lanes
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index deef0f5..a1e68b0 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -332,6 +332,8 @@ def test_function_with_dlpack_data_type_protocol() -> None:
x = DLPackDataTypeProtocol((dtype.type_code, dtype.bits, dtype.lanes))
y = fecho(x)
assert y == dtype
+ converted_y = tvm_ffi.convert(x)
+ assert converted_y == dtype
def test_function_with_dlpack_device_protocol() -> None:
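A minimal sketch of the __dlpack_data_type__ protocol handled by the new
convert branch (the class below is illustrative; the test suite uses a
similar helper):

    import tvm_ffi

    class MyDLPackDataType:
        """Any object exposing __dlpack_data_type__ converts to tvm_ffi.dtype."""

        def __dlpack_data_type__(self):
            # DLPack triple for float32: type code 2 (float), 32 bits, 1 lane
            return (2, 32, 1)

    d = tvm_ffi.convert(MyDLPackDataType())
    assert isinstance(d, tvm_ffi.dtype)
    assert (d.type_code, d.bits, d.lanes) == (2, 32, 1)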