This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 4dbc260 [BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch
extension loader (#585)
4dbc260 is described below
commit 4dbc2602ff1463d0ce14e2cc48218702e98262d6
Author: ZihaoMu <[email protected]>
AuthorDate: Wed May 13 22:41:57 2026 +0800
[BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch extension loader
(#585)
Motivation:
- If PyTorch already provides `__dlpack_c_exchange_api__`, tvm-ffi uses
it directly on all backends.
- The optional `torch-c-dlpack-ext` path is only used as a fallback for
older PyTorch builds that do not provide the API.
- The fallback extension library name is selected with backend-aware
detection for CPU, CUDA, and ROCm.
Tests are added for backend detection, ROCm short-circuit behavior when
PyTorch already provides the API, and GPU tensor metadata through
DLPack.
Related PR: https://github.com/tile-ai/tilelang/pull/2179. I have
finished the A/B test locally.
---
.../torch_c_dlpack_ext/torch_c_dlpack_ext/core.py | 15 +++-
python/tvm_ffi/_optional_torch_c_dlpack.py | 28 ++++----
tests/python/test_dlpack_exchange_api.py | 22 ++++++
tests/python/test_optional_torch_c_dlpack.py | 81 +++++++++++++++++++++-
4 files changed, 128 insertions(+), 18 deletions(-)
diff --git a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
index 6b63e0a..9d104d7 100644
--- a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
+++ b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
@@ -39,6 +39,17 @@ def _create_dlpack_exchange_api_capsule(ptr_as_int: int) ->
Any:
return capsule
+def _torch_extension_device(torch_module: Any) -> str:
+ """Return the torch backend name used in the optional extension library
name."""
+ if torch_module.cuda.is_available():
+ if getattr(torch_module.version, "cuda", None) is not None:
+ return "cuda"
+ if getattr(torch_module.version, "hip", None) is not None:
+ return "rocm"
+ return "cuda"
+ return "cpu"
+
+
def load_torch_c_dlpack_extension() -> None:
"""Load the torch c dlpack extension based on torch version."""
if hasattr(torch.Tensor, "__dlpack_c_exchange_api__") or hasattr(
@@ -52,10 +63,10 @@ def load_torch_c_dlpack_extension() -> None:
extension = "dylib"
else:
extension = "so"
- suffix = "cuda" if torch.cuda.is_available() else "cpu"
+ device = _torch_extension_device(torch)
lib_path = (
Path(__file__).parent
- /
f"libtorch_c_dlpack_addon_torch{version.major}{version.minor}-{suffix}.{extension}"
+ /
f"libtorch_c_dlpack_addon_torch{version.major}{version.minor}-{device}.{extension}"
)
if not lib_path.exists() or not lib_path.is_file():
raise ImportError("No matching prebuilt torch c dlpack extension")
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 7012108..5c804f1 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -44,6 +44,17 @@ from typing import Any
logger = logging.getLogger(__name__)
+def _torch_extension_device(torch_module: Any) -> str:
+ """Return the torch backend name used in the optional extension library
name."""
+ if torch_module.cuda.is_available():
+ if getattr(torch_module.version, "cuda", None) is not None:
+ return "cuda"
+ if getattr(torch_module.version, "hip", None) is not None:
+ return "rocm"
+ return "cuda"
+ return "cpu"
+
+
def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
"""Create a PyCapsule wrapping the DLPack exchange API pointer.
@@ -94,8 +105,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa:
PLR0912, PLR0915
import torch # noqa: PLC0415
import torch.version # noqa: PLC0415
- prefer_rocm_override = bool(torch.cuda.is_available() and
torch.version.hip is not None)
- if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not
prefer_rocm_override:
+ if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
# skip loading the extension if the __dlpack_c_exchange_api__
# attribute is already set so we don't have to do it in
# newer version of PyTorch
@@ -107,7 +117,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa:
PLR0912, PLR0915
try:
import torch_c_dlpack_ext # noqa: PLC0415, F401
- if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not
prefer_rocm_override:
+ if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
return None
except ImportError:
pass
@@ -122,17 +132,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa:
PLR0912, PLR0915
cache_dir = Path(os.environ.get("TVM_FFI_CACHE_DIR",
"~/.cache/tvm-ffi")).expanduser()
addon_output_dir = cache_dir
major, minor = torch.__version__.split(".")[:2]
- # First use "torch.cuda.is_available()" to check whether GPU
environment
- # is available. Then determine the GPU type.
- if torch.cuda.is_available():
- if torch.version.cuda is not None:
- device = "cuda"
- elif torch.version.hip is not None:
- device = "rocm"
- else:
- raise ValueError("Cannot determine whether to build with CUDA
or ROCm.")
- else:
- device = "cpu"
+ device = _torch_extension_device(torch)
suffix = ".dll" if sys.platform.startswith("win") else ".so"
libname =
f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
lib_path = addon_output_dir / libname
diff --git a/tests/python/test_dlpack_exchange_api.py
b/tests/python/test_dlpack_exchange_api.py
index 0891e11..0938a25 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -36,6 +36,7 @@ except ImportError:
# Check if DLPack Exchange API is available
_has_dlpack_api = torch is not None and hasattr(torch.Tensor,
"__dlpack_c_exchange_api__")
+_has_gpu = torch is not None and torch.cuda.is_available()
@pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API
not available")
@@ -214,6 +215,27 @@ def test_dlpack_exchange_api() -> None:
mod.test_dlpack_api(tensor, api_ptr, torch.cuda.is_available())
[email protected](
+ not (_has_dlpack_api and _has_gpu),
+ reason="PyTorch DLPack Exchange API with GPU is not available",
+)
+def test_dlpack_exchange_api_gpu_tensor_metadata() -> None:
+ assert torch is not None
+ echo = tvm_ffi.get_global_func("testing.echo")
+
+ for shape in [(512,), (512, 512), (2, 3, 4)]:
+ source = torch.empty(shape, device="cuda", dtype=torch.float16)
+
+ tvm_tensor = tvm_ffi.from_dlpack(source)
+ assert tvm_tensor.shape == shape
+ assert tvm_tensor.dtype == tvm_ffi.dtype("float16")
+
+ echoed = echo(source)
+ assert tuple(echoed.shape) == shape
+ assert echoed.dtype == source.dtype
+ assert echoed.device == source.device
+
+
@pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API
not available")
def test_from_dlpack_torch() -> None:
# Covers from_dlpack to use fallback fastpath
diff --git a/tests/python/test_optional_torch_c_dlpack.py
b/tests/python/test_optional_torch_c_dlpack.py
index 6640d6f..2a27892 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -15,10 +15,15 @@
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+import builtins
import ctypes
import subprocess
import sys
from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
import pytest
@@ -30,10 +35,81 @@ except ImportError:
import tvm_ffi
+from tvm_ffi import _optional_torch_c_dlpack
IS_WINDOWS = sys.platform.startswith("win")
+def _fake_torch_module(
+ *,
+ cuda_available: bool,
+ cuda_version: str | None = None,
+ hip_version: str | None = None,
+ include_cuda_attr: bool = True,
+ include_hip_attr: bool = True,
+) -> Any:
+ version = SimpleNamespace()
+ if include_cuda_attr:
+ version.cuda = cuda_version
+ if include_hip_attr:
+ version.hip = hip_version
+ return SimpleNamespace(
+ cuda=SimpleNamespace(is_available=lambda: cuda_available),
+ version=version,
+ )
+
+
+def test_torch_extension_device() -> None:
+ assert (
+ _optional_torch_c_dlpack._torch_extension_device(
+ _fake_torch_module(cuda_available=False, cuda_version=None,
hip_version=None)
+ )
+ == "cpu"
+ )
+ assert (
+ _optional_torch_c_dlpack._torch_extension_device(
+ _fake_torch_module(cuda_available=True, cuda_version="12.8",
hip_version=None)
+ )
+ == "cuda"
+ )
+ assert (
+ _optional_torch_c_dlpack._torch_extension_device(
+ _fake_torch_module(cuda_available=True, cuda_version=None,
hip_version="7.2")
+ )
+ == "rocm"
+ )
+ assert (
+ _optional_torch_c_dlpack._torch_extension_device(
+ _fake_torch_module(
+ cuda_available=True,
+ include_cuda_attr=False,
+ include_hip_attr=False,
+ )
+ )
+ == "cuda"
+ )
+
+
+def test_existing_torch_dlpack_api_is_preferred_on_rocm(monkeypatch:
pytest.MonkeyPatch) -> None:
+ torch_module = SimpleNamespace(
+ cuda=SimpleNamespace(is_available=lambda: True),
+ version=SimpleNamespace(cuda=None, hip="7.2"),
+ Tensor=SimpleNamespace(__dlpack_c_exchange_api__=object()),
+ )
+ original_import = builtins.__import__
+
+ def guarded_import(name: str, *args: Any, **kwargs: Any) -> Any:
+ if name == "torch_c_dlpack_ext":
+ raise AssertionError("torch_c_dlpack_ext should not be imported")
+ return original_import(name, *args, **kwargs)
+
+ monkeypatch.setitem(sys.modules, "torch", torch_module)
+ monkeypatch.setitem(sys.modules, "torch.version", torch_module.version)
+ monkeypatch.setattr(builtins, "__import__", guarded_import)
+
+ assert _optional_torch_c_dlpack.load_torch_c_dlpack_extension() is None
+
+
@pytest.mark.skipif(torch is None, reason="torch is not installed")
def test_build_torch_c_dlpack_extension() -> None:
assert torch is not None
@@ -49,9 +125,10 @@ def test_build_torch_c_dlpack_extension() -> None:
# First use "torch.cuda.is_available()" to check whether GPU environment
# is available. Then determine the GPU type.
if torch.cuda.is_available():
- if torch.version.cuda is not None:
+ device = _optional_torch_c_dlpack._torch_extension_device(torch)
+ if device == "cuda":
args.append("--build-with-cuda")
- elif torch.version.hip is not None:
+ elif device == "rocm":
args.append("--build-with-rocm")
else:
raise ValueError("Cannot determine whether to build with CUDA or
ROCm.")