This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dbc260  [BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch extension loader (#585)
4dbc260 is described below

commit 4dbc2602ff1463d0ce14e2cc48218702e98262d6
Author: ZihaoMu <[email protected]>
AuthorDate: Wed May 13 22:41:57 2026 +0800

    [BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch extension loader (#585)
    
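    Motivation:
    - On ROCm builds (GPU available and `torch.version.hip` set), the loader
    previously ignored a PyTorch-provided `__dlpack_c_exchange_api__` and
    still tried to load or build the extension. The prebuilt
    `torch_c_dlpack_ext` loader also selected the `-cuda` library name on ROCm.
    
    Changes:
    - If PyTorch already provides `__dlpack_c_exchange_api__`, tvm-ffi uses
    it directly on all backends.
    - The optional `torch-c-dlpack-ext` path is only used as a fallback for
    older PyTorch builds that do not provide the API.
    - The fallback extension library name is selected by a
    `_torch_extension_device` helper (added to both loaders) that returns
    `cpu`, `cuda`, or `rocm`. When a GPU is available but neither
    `torch.version.cuda` nor `torch.version.hip` is set, it now falls back
    to `cuda` instead of raising `ValueError`.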
    
    Tests are added for backend detection, ROCm short-circuit behavior when
    PyTorch already provides the API, and GPU tensor metadata through
    DLPack.
    
    
    Related PR: https://github.com/tile-ai/tilelang/pull/2179. I have
    completed an A/B test of this change locally.
---
 .../torch_c_dlpack_ext/torch_c_dlpack_ext/core.py  | 15 +++-
 python/tvm_ffi/_optional_torch_c_dlpack.py         | 28 ++++----
 tests/python/test_dlpack_exchange_api.py           | 22 ++++++
 tests/python/test_optional_torch_c_dlpack.py       | 81 +++++++++++++++++++++-
 4 files changed, 128 insertions(+), 18 deletions(-)

diff --git a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
index 6b63e0a..9d104d7 100644
--- a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
+++ b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
@@ -39,6 +39,17 @@ def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
     return capsule
 
 
+def _torch_extension_device(torch_module: Any) -> str:
+    """Return the torch backend name used in the optional extension library name."""
+    if torch_module.cuda.is_available():
+        if getattr(torch_module.version, "cuda", None) is not None:
+            return "cuda"
+        if getattr(torch_module.version, "hip", None) is not None:
+            return "rocm"
+        return "cuda"
+    return "cpu"
+
+
 def load_torch_c_dlpack_extension() -> None:
     """Load the torch c dlpack extension based on torch version."""
     if hasattr(torch.Tensor, "__dlpack_c_exchange_api__") or hasattr(
@@ -52,10 +63,10 @@ def load_torch_c_dlpack_extension() -> None:
         extension = "dylib"
     else:
         extension = "so"
-    suffix = "cuda" if torch.cuda.is_available() else "cpu"
+    device = _torch_extension_device(torch)
     lib_path = (
         Path(__file__).parent
-        / f"libtorch_c_dlpack_addon_torch{version.major}{version.minor}-{suffix}.{extension}"
+        / f"libtorch_c_dlpack_addon_torch{version.major}{version.minor}-{device}.{extension}"
     )
     if not lib_path.exists() or not lib_path.is_file():
         raise ImportError("No matching prebuilt torch c dlpack extension")
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 7012108..5c804f1 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -44,6 +44,17 @@ from typing import Any
 logger = logging.getLogger(__name__)
 
 
+def _torch_extension_device(torch_module: Any) -> str:
+    """Return the torch backend name used in the optional extension library name."""
+    if torch_module.cuda.is_available():
+        if getattr(torch_module.version, "cuda", None) is not None:
+            return "cuda"
+        if getattr(torch_module.version, "hip", None) is not None:
+            return "rocm"
+        return "cuda"
+    return "cpu"
+
+
 def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
     """Create a PyCapsule wrapping the DLPack exchange API pointer.
 
@@ -94,8 +105,7 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
         import torch  # noqa: PLC0415
         import torch.version  # noqa: PLC0415
 
-        prefer_rocm_override = bool(torch.cuda.is_available() and torch.version.hip is not None)
-        if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not prefer_rocm_override:
+        if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
             # skip loading the extension if the __dlpack_c_exchange_api__
             # attribute is already set so we don't have to do it in
             # newer version of PyTorch
@@ -107,7 +117,7 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
     try:
         import torch_c_dlpack_ext  # noqa: PLC0415, F401
 
-        if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not prefer_rocm_override:
+        if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
             return None
     except ImportError:
         pass
@@ -122,17 +132,7 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
         cache_dir = Path(os.environ.get("TVM_FFI_CACHE_DIR", "~/.cache/tvm-ffi")).expanduser()
         addon_output_dir = cache_dir
         major, minor = torch.__version__.split(".")[:2]
-        # First use "torch.cuda.is_available()" to check whether GPU environment
-        # is available. Then determine the GPU type.
-        if torch.cuda.is_available():
-            if torch.version.cuda is not None:
-                device = "cuda"
-            elif torch.version.hip is not None:
-                device = "rocm"
-            else:
-                raise ValueError("Cannot determine whether to build with CUDA or ROCm.")
-        else:
-            device = "cpu"
+        device = _torch_extension_device(torch)
         suffix = ".dll" if sys.platform.startswith("win") else ".so"
         libname = f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
         lib_path = addon_output_dir / libname
diff --git a/tests/python/test_dlpack_exchange_api.py b/tests/python/test_dlpack_exchange_api.py
index 0891e11..0938a25 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -36,6 +36,7 @@ except ImportError:
 
 # Check if DLPack Exchange API is available
 _has_dlpack_api = torch is not None and hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
+_has_gpu = torch is not None and torch.cuda.is_available()
 
 
 @pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API not available")
@@ -214,6 +215,27 @@ def test_dlpack_exchange_api() -> None:
     mod.test_dlpack_api(tensor, api_ptr, torch.cuda.is_available())
 
 
+@pytest.mark.skipif(
+    not (_has_dlpack_api and _has_gpu),
+    reason="PyTorch DLPack Exchange API with GPU is not available",
+)
+def test_dlpack_exchange_api_gpu_tensor_metadata() -> None:
+    assert torch is not None
+    echo = tvm_ffi.get_global_func("testing.echo")
+
+    for shape in [(512,), (512, 512), (2, 3, 4)]:
+        source = torch.empty(shape, device="cuda", dtype=torch.float16)
+
+        tvm_tensor = tvm_ffi.from_dlpack(source)
+        assert tvm_tensor.shape == shape
+        assert tvm_tensor.dtype == tvm_ffi.dtype("float16")
+
+        echoed = echo(source)
+        assert tuple(echoed.shape) == shape
+        assert echoed.dtype == source.dtype
+        assert echoed.device == source.device
+
+
 @pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API not available")
 def test_from_dlpack_torch() -> None:
     # Covers from_dlpack to use fallback fastpath
diff --git a/tests/python/test_optional_torch_c_dlpack.py b/tests/python/test_optional_torch_c_dlpack.py
index 6640d6f..2a27892 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -15,10 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import annotations
+
+import builtins
 import ctypes
 import subprocess
 import sys
 from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
 
 import pytest
 
@@ -30,10 +35,81 @@ except ImportError:
 
 
 import tvm_ffi
+from tvm_ffi import _optional_torch_c_dlpack
 
 IS_WINDOWS = sys.platform.startswith("win")
 
 
+def _fake_torch_module(
+    *,
+    cuda_available: bool,
+    cuda_version: str | None = None,
+    hip_version: str | None = None,
+    include_cuda_attr: bool = True,
+    include_hip_attr: bool = True,
+) -> Any:
+    version = SimpleNamespace()
+    if include_cuda_attr:
+        version.cuda = cuda_version
+    if include_hip_attr:
+        version.hip = hip_version
+    return SimpleNamespace(
+        cuda=SimpleNamespace(is_available=lambda: cuda_available),
+        version=version,
+    )
+
+
+def test_torch_extension_device() -> None:
+    assert (
+        _optional_torch_c_dlpack._torch_extension_device(
+            _fake_torch_module(cuda_available=False, cuda_version=None, hip_version=None)
+        )
+        == "cpu"
+    )
+    assert (
+        _optional_torch_c_dlpack._torch_extension_device(
+            _fake_torch_module(cuda_available=True, cuda_version="12.8", hip_version=None)
+        )
+        == "cuda"
+    )
+    assert (
+        _optional_torch_c_dlpack._torch_extension_device(
+            _fake_torch_module(cuda_available=True, cuda_version=None, hip_version="7.2")
+        )
+        == "rocm"
+    )
+    assert (
+        _optional_torch_c_dlpack._torch_extension_device(
+            _fake_torch_module(
+                cuda_available=True,
+                include_cuda_attr=False,
+                include_hip_attr=False,
+            )
+        )
+        == "cuda"
+    )
+
+
+def test_existing_torch_dlpack_api_is_preferred_on_rocm(monkeypatch: pytest.MonkeyPatch) -> None:
+    torch_module = SimpleNamespace(
+        cuda=SimpleNamespace(is_available=lambda: True),
+        version=SimpleNamespace(cuda=None, hip="7.2"),
+        Tensor=SimpleNamespace(__dlpack_c_exchange_api__=object()),
+    )
+    original_import = builtins.__import__
+
+    def guarded_import(name: str, *args: Any, **kwargs: Any) -> Any:
+        if name == "torch_c_dlpack_ext":
+            raise AssertionError("torch_c_dlpack_ext should not be imported")
+        return original_import(name, *args, **kwargs)
+
+    monkeypatch.setitem(sys.modules, "torch", torch_module)
+    monkeypatch.setitem(sys.modules, "torch.version", torch_module.version)
+    monkeypatch.setattr(builtins, "__import__", guarded_import)
+
+    assert _optional_torch_c_dlpack.load_torch_c_dlpack_extension() is None
+
+
 @pytest.mark.skipif(torch is None, reason="torch is not installed")
 def test_build_torch_c_dlpack_extension() -> None:
     assert torch is not None
@@ -49,9 +125,10 @@ def test_build_torch_c_dlpack_extension() -> None:
     # First use "torch.cuda.is_available()" to check whether GPU environment
     # is available. Then determine the GPU type.
     if torch.cuda.is_available():
-        if torch.version.cuda is not None:
+        device = _optional_torch_c_dlpack._torch_extension_device(torch)
+        if device == "cuda":
             args.append("--build-with-cuda")
-        elif torch.version.hip is not None:
+        elif device == "rocm":
             args.append("--build-with-rocm")
         else:
             raise ValueError("Cannot determine whether to build with CUDA or ROCm.")

Reply via email to