This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c0add28 [FIX] Fix device type override (#40)
c0add28 is described below
commit c0add281b0973aacc88f3185a84e0768c67306f4
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 22 11:23:53 2025 -0400
[FIX] Fix device type override (#40)
This PR fixes the behavior of `tvm_ffi.device` so that a Device class override
is picked up correctly, and adds unit tests for the Device and Tensor overrides.
---
pyproject.toml | 2 +-
python/tvm_ffi/_tensor.py | 6 +++---
tests/python/test_device.py | 12 ++++++++++++
tests/python/test_tensor.py | 18 ++++++++++++++++++
4 files changed, 34 insertions(+), 4 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 1e97de7..d2fd897 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b5"
+version = "0.1.0b6"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index 8d06bd2..0cc09f1 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -21,9 +21,8 @@
from numbers import Integral
from typing import Any, Optional, Union
-from . import _ffi_api, registry
+from . import _ffi_api, core, registry
from .core import (
- _CLASS_DEVICE,
Device,
DLDeviceType,
PyNativeObject,
@@ -86,7 +85,8 @@ def device(device_type: Union[str, int, DLDeviceType], index: Optional[int] = No
assert tvm_ffi.device("cpu:0") == tvm_ffi.device("cpu", 0)
"""
- return _CLASS_DEVICE(device_type, index)
+ # must refer to core._CLASS_DEVICE so we pick up override here
+ return core._CLASS_DEVICE(device_type, index)
__all__ = ["DLDeviceType", "Device", "Tensor", "device", "from_dlpack"]
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 30c964a..9441c9f 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -95,3 +95,15 @@ def test_device_pickle() -> None:
device_pickled = pickle.loads(pickle.dumps(device))
assert device_pickled.dlpack_device_type() == device.dlpack_device_type()
assert device_pickled.index == device.index
+
+
+def test_device_class_override() -> None:
+ class MyDevice(tvm_ffi.Device):
+ pass
+
+ old_device = tvm_ffi.core._CLASS_DEVICE
+ tvm_ffi.core._set_class_device(MyDevice)
+
+ device = tvm_ffi.device("cuda", 0)
+ assert isinstance(device, MyDevice)
+ tvm_ffi.core._set_class_device(old_device)
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 6d1da26..4c2e9a8 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -66,3 +66,21 @@ def test_tensor_auto_dlpack() -> None:
assert y.shape == x.shape
assert y.device == x.device
np.testing.assert_equal(y.numpy(), x.numpy())
+
+
+def test_tensor_class_override() -> None:
+ class MyTensor(tvm_ffi.Tensor):
+ pass
+
+ old_tensor = tvm_ffi.core._CLASS_TENSOR
+ tvm_ffi.core._set_class_tensor(MyTensor)
+
+ data = np.zeros((10, 8, 4, 2), dtype="int16")
+ if not hasattr(data, "__dlpack__"):
+ return
+ x = tvm_ffi.from_dlpack(data)
+
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ y = fecho(x)
+ assert isinstance(y, MyTensor)
+ tvm_ffi.core._set_class_tensor(old_tensor)