This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new a7ebc65 Allow `tvm_ffi.device(..., id)` where id is numpy or torch
scalar (#347)
a7ebc65 is described below
commit a7ebc65f14eecd1592d407f1d5c952c65603a9aa
Author: wrongtest <[email protected]>
AuthorDate: Thu Dec 18 02:17:15 2025 +0800
Allow `tvm_ffi.device(..., id)` where id is numpy or torch scalar (#347)
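This mirrors `torch.device`, which accepts the device index as a Python int, a NumPy integer scalar, or a 0-d integer tensor: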
```python
torch.device("cuda", 0)
torch.device("cuda", numpy.int32(1))
torch.device("cuda", torch.tensor(1, dtype=torch.int32))
```

With this change, `tvm_ffi.device` accepts the same forms:

```python
tvm_ffi.device("cuda", 0)
tvm_ffi.device("cuda", numpy.int32(1))
tvm_ffi.device("cuda", torch.tensor(1, dtype=torch.int32))
```
---------
Co-authored-by: baoxinqi <[email protected]>
---
python/tvm_ffi/cython/device.pxi | 11 ++++++++---
tests/python/test_device.py | 5 +++++
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 9539827..2eb36fc 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from enum import IntEnum
+from numbers import Integral
from typing import Any, Optional
_CLASS_DEVICE = None
@@ -129,7 +130,7 @@ cdef class Device:
"trn": DLDeviceType.kDLTrn,
}
- def __init__(self, device_type: str | int, index: Optional[int] = None) ->
None:
+ def __init__(self, device_type: str | int, index: Optional[Integral] =
None) -> None:
device_type_or_name = device_type
index = index if index is not None else 0
if isinstance(device_type_or_name, str):
@@ -148,8 +149,12 @@ cdef class Device:
raise ValueError(f"Invalid device index: {parts[1]}")
else:
device_type = device_type_or_name
- if not isinstance(index, int):
- raise TypeError(f"Invalid device index: {index}")
+
+ if not isinstance(index, Integral):
+ if hasattr(index, "item") and callable(index.item):
+ index = index.item()
+ if not isinstance(index, Integral):
+ raise TypeError(f"Invalid device index: {index}")
self.cdevice = TVMFFIDLDeviceFromIntPair(device_type, index)
def __reduce__(self) -> Any:
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 71dc0e4..33b48e8 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import ctypes
import pickle
+import numpy
import pytest
import tvm_ffi
from tvm_ffi import DLDeviceType
@@ -69,6 +70,10 @@ def test_device_dlpack_device_type(
(DLDeviceType.kDLCUDA, 0, DLDeviceType.kDLCUDA, 0),
("cuda", 3, DLDeviceType.kDLCUDA, 3),
(DLDeviceType.kDLMetal, 2, DLDeviceType.kDLMetal, 2),
+ # id from numpy
+ ("cpu", numpy.int32(1), DLDeviceType.kDLCPU, 1),
+ # id from torch (py dependency not ready in environment)
+ # ("cpu", torch.tensor(1, dtype=torch.int32), DLDeviceType.kDLCPU, 1),
],
)
def test_device_with_dev_id(