This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6897a5f fix(cython): Make sure `TypeInfo` is properly registered by all classes (#246)
6897a5f is described below
commit 6897a5f5a8dc662270dd10be40cf261bd1da93f8
Author: Junru Shao <[email protected]>
AuthorDate: Sat Nov 8 17:54:28 2025 -0800
fix(cython): Make sure `TypeInfo` is properly registered by all classes (#246)
Before this commit, `parent_type_info` was not properly propagated for a few
Cython classes because of the order in which they were registered. This PR
fixes the registration order and adds tests to safeguard against regressions.
---
python/tvm_ffi/core.pyi | 2 +-
python/tvm_ffi/cython/core.pyx | 16 +++++++++--
python/tvm_ffi/cython/device.pxi | 2 --
python/tvm_ffi/cython/error.pxi | 3 --
python/tvm_ffi/cython/function.pxi | 2 --
python/tvm_ffi/cython/object.pxi | 10 ++++++-
python/tvm_ffi/cython/string.pxi | 7 -----
python/tvm_ffi/cython/tensor.pxi | 1 -
tests/python/test_object.py | 59 ++++++++++++++++++++++++++++++++++++++
9 files changed, 83 insertions(+), 19 deletions(-)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 8141748..261803c 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -62,12 +62,12 @@ class PyNativeObject:
     __slots__: list[str]
     def __init_cached_object_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ...
 
-def _set_class_object(cls: type) -> None: ...
 def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
 def _object_type_key_to_index(type_key: str) -> int | None: ...
 def _set_type_cls(type_info: TypeInfo, type_cls: type) -> None: ...
 def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: ...
 def _lookup_type_attr(type_index: int, attr_key: str) -> Any: ...
+def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None: ...
 
 class Error(Object):
     def __init__(self, kind: str, message: str, backtrace: str) -> None: ...
diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx
index 90679a2..6ffdc7e 100644
--- a/python/tvm_ffi/cython/core.pyx
+++ b/python/tvm_ffi/cython/core.pyx
@@ -17,12 +17,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+# N.B. Make sure `_register_object_by_index` is called in inheritance order,
+# where the base class has to be registered before the derived class.
+# Otherwise, `TypeInfo.parent_type_info` may not be properly propagated to the derived class.
 include "./base.pxi"
 include "./type_info.pxi"
-include "./dtype.pxi"
-include "./device.pxi"
 include "./object.pxi"
+_register_object_by_index(kTVMFFIObject, Object)
 include "./error.pxi"
+_register_object_by_index(kTVMFFIError, Error)
+include "./dtype.pxi"
+_register_object_by_index(kTVMFFIDataType, DataType)
+include "./device.pxi"
+_register_object_by_index(kTVMFFIDevice, Device)
 include "./string.pxi"
+_register_object_by_index(kTVMFFIStr, String)
+_register_object_by_index(kTVMFFIBytes, Bytes)
 include "./tensor.pxi"
+_register_object_by_index(kTVMFFITensor, Tensor)
 include "./function.pxi"
+_register_object_by_index(kTVMFFIFunction, Function)
diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 413ef40..9539827 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -1,5 +1,3 @@
-
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
diff --git a/python/tvm_ffi/cython/error.pxi b/python/tvm_ffi/cython/error.pxi
index 9ade8c2..d85cb87 100644
--- a/python/tvm_ffi/cython/error.pxi
+++ b/python/tvm_ffi/cython/error.pxi
@@ -104,9 +104,6 @@ cdef class Error(Object):
         return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).backtrace))
 
 
-_register_object_by_index(kTVMFFIError, Error)
-
-
 cdef inline Error move_from_last_error():
     # raise last error
     error = Error.__new__(Error)
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 2db7c1a..c1fb6a2 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -1007,8 +1007,6 @@ cdef class Function(Object):
         (<Object>func).chandle = chandle
         return func
 
-_register_object_by_index(kTVMFFIFunction, Function)
-
 
 def _register_global_func(name: str, pyfunc: Callable[..., Any] | Function, override: bool) -> Function:
     cdef TVMFFIObjectHandle chandle
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 49c7595..c8e1508 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -487,6 +487,8 @@ cdef _update_registry(int type_index, object type_key, object type_info, object
     TYPE_INDEX_TO_CLS[type_index] = type_cls
     TYPE_INDEX_TO_INFO[type_index] = type_info
     TYPE_KEY_TO_INFO[type_key] = type_info
+    if type_cls is not None:
+        TYPE_CLS_TO_INFO[type_cls] = type_info
 
 
 def _register_object_by_index(int type_index, object type_cls):
@@ -498,12 +500,13 @@ def _register_object_by_index(int type_index, object type_cls):
 
 
 def _set_type_cls(object type_info, object type_cls):
-    global TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS
+    global TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS, TYPE_CLS_TO_INFO
     assert type_info.type_cls is None, f"Type already registered for {type_info.type_key}"
     assert TYPE_INDEX_TO_INFO[type_info.type_index] is type_info
     assert TYPE_KEY_TO_INFO[type_info.type_key] is type_info
     type_info.type_cls = type_cls
     TYPE_INDEX_TO_CLS[type_info.type_index] = type_cls
+    TYPE_CLS_TO_INFO[type_cls] = type_info
 
 
 def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo:
@@ -523,8 +526,13 @@ def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any:
     return make_ret(column.data[type_index])
 
 
+def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None:
+    return TYPE_CLS_TO_INFO.get(type_cls, None)
+
+
 cdef list TYPE_INDEX_TO_CLS = []
 cdef list TYPE_INDEX_TO_INFO = []
+cdef dict TYPE_CLS_TO_INFO = {}
 cdef dict TYPE_KEY_TO_INFO = {}
 
 _set_class_object(Object)
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index d7bce4d..2c23a08 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -66,9 +66,6 @@ class String(str, PyNativeObject):
         return val
 
 
-_register_object_by_index(kTVMFFIStr, String)
-
-
 class Bytes(bytes, PyNativeObject):
     """Byte buffer that interoperates with FFI while behaving like ``bytes``.
 
@@ -90,7 +87,3 @@ class Bytes(bytes, PyNativeObject):
         val = bytes.__new__(cls, content)
         val._tvm_ffi_cached_object = obj
         return val
-
-
-_register_object_by_index(kTVMFFIBytes, Bytes)
-_register_object_by_index(kTVMFFIObject, Object)
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 0985056..e6fba34 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -330,7 +330,6 @@ cdef class Tensor(Object):
 
 
 _set_class_tensor(Tensor)
-_register_object_by_index(kTVMFFITensor, Tensor)
 
 
 cdef int _dltensor_test_wrapper_from_pyobject(
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 4701a74..c39b0be 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -14,11 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
+
 import sys
 from typing import Any
 
 import pytest
 import tvm_ffi
+import tvm_ffi.testing
 from tvm_ffi.core import TypeInfo
 
 
@@ -169,3 +172,59 @@ def test_unregistered_object_fallback() -> None:
     for _ in range(5):
         obj = tvm_ffi.testing.make_unregistered_object()
         _check_type(obj)
+
+
+@pytest.mark.parametrize(
+    "test_cls, type_key, parent_cls",
+    [
+        (tvm_ffi.Object, "ffi.Object", None),
+        (tvm_ffi.Tensor, "ffi.Tensor", tvm_ffi.Object),
+        (tvm_ffi.core.DataType, "DataType", None),
+        (tvm_ffi.Device, "Device", None),
+        (tvm_ffi.Shape, "ffi.Shape", tvm_ffi.Object),
+        (tvm_ffi.Module, "ffi.Module", tvm_ffi.Object),
+        (tvm_ffi.Function, "ffi.Function", tvm_ffi.Object),
+        (tvm_ffi.core.Error, "ffi.Error", tvm_ffi.Object),
+        (tvm_ffi.core.String, "ffi.String", tvm_ffi.Object),
+        (tvm_ffi.core.Bytes, "ffi.Bytes", tvm_ffi.Object),
+        (tvm_ffi.Tensor, "ffi.Tensor", tvm_ffi.Object),
+        (tvm_ffi.Array, "ffi.Array", tvm_ffi.Object),
+        (tvm_ffi.Map, "ffi.Map", tvm_ffi.Object),
+        (tvm_ffi.access_path.AccessStep, "ffi.reflection.AccessStep", tvm_ffi.Object),
+        (tvm_ffi.access_path.AccessPath, "ffi.reflection.AccessPath", tvm_ffi.Object),
+        (tvm_ffi.testing.TestIntPair, "testing.TestIntPair", tvm_ffi.Object),
+        (tvm_ffi.testing.TestObjectBase, "testing.TestObjectBase", tvm_ffi.Object),
+        (
+            tvm_ffi.testing.TestObjectDerived,
+            "testing.TestObjectDerived",
+            tvm_ffi.testing.TestObjectBase,
+        ),
+        (tvm_ffi.testing._TestCxxClassBase, "testing.TestCxxClassBase", tvm_ffi.Object),
+        (
+            tvm_ffi.testing._TestCxxClassDerived,
+            "testing.TestCxxClassDerived",
+            tvm_ffi.testing._TestCxxClassBase,
+        ),
+        (
+            tvm_ffi.testing._TestCxxClassDerivedDerived,
+            "testing.TestCxxClassDerivedDerived",
+            tvm_ffi.testing._TestCxxClassDerived,
+        ),
+        (tvm_ffi.testing._TestCxxInitSubset, "testing.TestCxxInitSubset", tvm_ffi.Object),
+        (tvm_ffi.testing._SchemaAllTypes, "testing.SchemaAllTypes", tvm_ffi.Object),
+    ],
+)
+def test_type_info_attachment(test_cls: type, type_key: str, parent_cls: type | None) -> None:
+    type_info = tvm_ffi.core._type_cls_to_type_info(test_cls)
+    assert type_info is not None
+    assert type_info.type_cls is test_cls
+    assert type_info.type_key == type_key, (
+        f"Expected type key `{type_key}` but got `{type_info.type_key}`"
+    )
+    parent_type_info = type_info.parent_type_info
+    if parent_type_info is None:
+        assert parent_cls is None
+    else:
+        assert parent_type_info.type_cls is parent_cls, (
+            f"Expected parent type {parent_cls}, but got {parent_type_info.type_cls}"
+        )