This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new be9a827 fix(py_class): support `tvm_ffi.dtype` and `tvm_ffi.Device`
as field type annotations (#540)
be9a827 is described below
commit be9a8274623252b48ac252961f5968255bf51114
Author: Junru Shao <[email protected]>
AuthorDate: Sun Apr 12 13:05:04 2026 -0700
fix(py_class): support `tvm_ffi.dtype` and `tvm_ffi.Device` as field type
annotations (#540)
## Summary
- `TypeSchema.from_annotation()` in the Cython layer only recognized the
C-level `DataType`/`Device` cdef classes for dtype/device annotations
(e.g. `annotation is DataType`). The public Python wrapper classes —
`tvm_ffi.dtype` (`class dtype(str)` in `_dtype.py`) and `tvm_ffi.Device`
— are distinct types, so using them as `@py_class` field type
annotations raised `TypeError: Cannot convert <class '...'> to
TypeSchema`.
- The fix extends the identity checks to also match the Python wrapper
classes via the existing module-level `_CLASS_DTYPE` / `_CLASS_DEVICE`
class references, each checked against `None` before the comparison.
- Added 6 regression tests covering dtype fields, Device fields,
combined usage, setter mutation, and `Optional` variants.
## Test plan
- [x] `uv run pytest -vvs tests/python/test_dataclass_py_class.py` — all
342 tests pass (including 6 new `TestDtypeDeviceFields` tests)
- [ ] CI: lint, C++ tests, full Python test suite, Rust tests
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
python/tvm_ffi/cython/type_info.pxi | 4 +--
tests/python/test_dataclass_py_class.py | 64 +++++++++++++++++++++++++++++++++
2 files changed, 66 insertions(+), 2 deletions(-)
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 67be0d8..118e067 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -404,9 +404,9 @@ class TypeSchema:
return TypeSchema("bytes")
# --- Non-CObject cdef classes with known origins ---
- if annotation is DataType:
+ if annotation is DataType or (_CLASS_DTYPE is not None and annotation
is _CLASS_DTYPE):
return TypeSchema("dtype")
- if annotation is Device:
+ if annotation is Device or (_CLASS_DEVICE is not None and annotation
is _CLASS_DEVICE):
return TypeSchema("Device")
# --- ctypes.c_void_p ---
diff --git a/tests/python/test_dataclass_py_class.py
b/tests/python/test_dataclass_py_class.py
index 57a7c1c..a9c5358 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -4969,3 +4969,67 @@ class TestSuperInitPattern:
assert obj3.x == 42
assert obj3.y == "hello"
assert not obj.same_as(obj3)
+
+
+class TestDtypeDeviceFields:
+ """Regression: @py_class should accept tvm_ffi.dtype and tvm_ffi.Device as
field types."""
+
+ def test_dtype_field(self) -> None:
+ @py_class(_unique_key("DtypeField"))
+ class DtypeHolder(Object):
+ dt: tvm_ffi.dtype
+
+ obj = DtypeHolder(dt=tvm_ffi.dtype("float32"))
+ assert obj.dt == "float32"
+ assert isinstance(obj.dt, tvm_ffi.dtype)
+
+ def test_dtype_field_setter(self) -> None:
+ @py_class(_unique_key("DtypeFieldSet"))
+ class DtypeHolder2(Object):
+ dt: tvm_ffi.dtype
+
+ obj = DtypeHolder2(dt=tvm_ffi.dtype("float32"))
+ obj.dt = tvm_ffi.dtype("int8")
+ assert obj.dt == "int8"
+
+ def test_device_field(self) -> None:
+ @py_class(_unique_key("DeviceField"))
+ class DeviceHolder(Object):
+ dev: tvm_ffi.Device
+
+ dev = tvm_ffi.device("cpu", 0)
+ obj = DeviceHolder(dev=dev)
+ assert obj.dev == dev
+
+ def test_dtype_device_together(self) -> None:
+ @py_class(_unique_key("DtypeDeviceTogether"))
+ class DtypeDeviceHolder(Object):
+ dt: tvm_ffi.dtype
+ dev: tvm_ffi.Device
+ name: str
+
+ dev = tvm_ffi.device("cpu", 0)
+ obj = DtypeDeviceHolder(dt=tvm_ffi.dtype("float16"), dev=dev,
name="test")
+ assert obj.dt == "float16"
+ assert obj.dev == dev
+ assert obj.name == "test"
+
+ def test_optional_dtype_field(self) -> None:
+ @py_class(_unique_key("OptDtype"))
+ class OptDtype(Object):
+ dt: Optional[tvm_ffi.dtype] = None
+
+ obj_none = OptDtype()
+ assert obj_none.dt is None
+ obj_val = OptDtype(dt=tvm_ffi.dtype("bfloat16"))
+ assert obj_val.dt == "bfloat16"
+
+ def test_optional_device_field(self) -> None:
+ @py_class(_unique_key("OptDevice"))
+ class OptDevice(Object):
+ dev: Optional[tvm_ffi.Device] = None
+
+ obj_none = OptDevice()
+ assert obj_none.dev is None
+ obj_val = OptDevice(dev=tvm_ffi.device("cpu", 0))
+ assert obj_val.dev == tvm_ffi.device("cpu", 0)