This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 6897a5f  fix(cython): Make sure `TypeInfo` is properly registered by all classes (#246)
6897a5f is described below

commit 6897a5f5a8dc662270dd10be40cf261bd1da93f8
Author: Junru Shao <[email protected]>
AuthorDate: Sat Nov 8 17:54:28 2025 -0800

    fix(cython): Make sure `TypeInfo` is properly registered by all classes (#246)
    
    Before this commit, `parent_type_info` of a few Cython classes was not
    properly propagated because of the registration order. This PR fixes
    this behavior and adds tests to safeguard this case.
---
 python/tvm_ffi/core.pyi            |  2 +-
 python/tvm_ffi/cython/core.pyx     | 16 +++++++++--
 python/tvm_ffi/cython/device.pxi   |  2 --
 python/tvm_ffi/cython/error.pxi    |  3 --
 python/tvm_ffi/cython/function.pxi |  2 --
 python/tvm_ffi/cython/object.pxi   | 10 ++++++-
 python/tvm_ffi/cython/string.pxi   |  7 -----
 python/tvm_ffi/cython/tensor.pxi   |  1 -
 tests/python/test_object.py        | 59 ++++++++++++++++++++++++++++++++++++++
 9 files changed, 83 insertions(+), 19 deletions(-)

diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 8141748..261803c 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -62,12 +62,12 @@ class PyNativeObject:
     __slots__: list[str]
     def __init_cached_object_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ...
 
-def _set_class_object(cls: type) -> None: ...
 def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
 def _object_type_key_to_index(type_key: str) -> int | None: ...
 def _set_type_cls(type_info: TypeInfo, type_cls: type) -> None: ...
 def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: ...
 def _lookup_type_attr(type_index: int, attr_key: str) -> Any: ...
+def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None: ...
 
 class Error(Object):
     def __init__(self, kind: str, message: str, backtrace: str) -> None: ...
diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx
index 90679a2..6ffdc7e 100644
--- a/python/tvm_ffi/cython/core.pyx
+++ b/python/tvm_ffi/cython/core.pyx
@@ -17,12 +17,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+# N.B. Make sure `_register_object_by_index` is called in inheritance order,
+# where the base class has to be registered before the derived class.
+# Otherwise, `TypeInfo.parent_type_info` may not be properly propagated to the derived class.
 include "./base.pxi"
 include "./type_info.pxi"
-include "./dtype.pxi"
-include "./device.pxi"
 include "./object.pxi"
+_register_object_by_index(kTVMFFIObject, Object)
 include "./error.pxi"
+_register_object_by_index(kTVMFFIError, Error)
+include "./dtype.pxi"
+_register_object_by_index(kTVMFFIDataType, DataType)
+include "./device.pxi"
+_register_object_by_index(kTVMFFIDevice, Device)
 include "./string.pxi"
+_register_object_by_index(kTVMFFIStr, String)
+_register_object_by_index(kTVMFFIBytes, Bytes)
 include "./tensor.pxi"
+_register_object_by_index(kTVMFFITensor, Tensor)
 include "./function.pxi"
+_register_object_by_index(kTVMFFIFunction, Function)
diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi
index 413ef40..9539827 100644
--- a/python/tvm_ffi/cython/device.pxi
+++ b/python/tvm_ffi/cython/device.pxi
@@ -1,5 +1,3 @@
-
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
diff --git a/python/tvm_ffi/cython/error.pxi b/python/tvm_ffi/cython/error.pxi
index 9ade8c2..d85cb87 100644
--- a/python/tvm_ffi/cython/error.pxi
+++ b/python/tvm_ffi/cython/error.pxi
@@ -104,9 +104,6 @@ cdef class Error(Object):
         return bytearray_to_str(&(TVMFFIErrorGetCellPtr(self.chandle).backtrace))
 
 
-_register_object_by_index(kTVMFFIError, Error)
-
-
 cdef inline Error move_from_last_error():
     # raise last error
     error = Error.__new__(Error)
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 2db7c1a..c1fb6a2 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -1007,8 +1007,6 @@ cdef class Function(Object):
         (<Object>func).chandle = chandle
         return func
 
-_register_object_by_index(kTVMFFIFunction, Function)
-
 
 def _register_global_func(name: str, pyfunc: Callable[..., Any] | Function, override: bool) -> Function:
     cdef TVMFFIObjectHandle chandle
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 49c7595..c8e1508 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -487,6 +487,8 @@ cdef _update_registry(int type_index, object type_key, object type_info, object
     TYPE_INDEX_TO_CLS[type_index] = type_cls
     TYPE_INDEX_TO_INFO[type_index] = type_info
     TYPE_KEY_TO_INFO[type_key] = type_info
+    if type_cls is not None:
+        TYPE_CLS_TO_INFO[type_cls] = type_info
 
 
 def _register_object_by_index(int type_index, object type_cls):
@@ -498,12 +500,13 @@ def _register_object_by_index(int type_index, object type_cls):
 
 
 def _set_type_cls(object type_info, object type_cls):
-    global TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS
+    global TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS, TYPE_CLS_TO_INFO
     assert type_info.type_cls is None, f"Type already registered for {type_info.type_key}"
     assert TYPE_INDEX_TO_INFO[type_info.type_index] is type_info
     assert TYPE_KEY_TO_INFO[type_info.type_key] is type_info
     type_info.type_cls = type_cls
     TYPE_INDEX_TO_CLS[type_info.type_index] = type_cls
+    TYPE_CLS_TO_INFO[type_cls] = type_info
 
 
 def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo:
@@ -523,8 +526,13 @@ def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any:
     return make_ret(column.data[type_index])
 
 
+def _type_cls_to_type_info(type_cls: type) -> TypeInfo | None:
+    return TYPE_CLS_TO_INFO.get(type_cls, None)
+
+
 cdef list TYPE_INDEX_TO_CLS = []
 cdef list TYPE_INDEX_TO_INFO = []
+cdef dict TYPE_CLS_TO_INFO = {}
 cdef dict TYPE_KEY_TO_INFO = {}
 
 _set_class_object(Object)
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index d7bce4d..2c23a08 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -66,9 +66,6 @@ class String(str, PyNativeObject):
         return val
 
 
-_register_object_by_index(kTVMFFIStr, String)
-
-
 class Bytes(bytes, PyNativeObject):
     """Byte buffer that interoperates with FFI while behaving like ``bytes``.
 
@@ -90,7 +87,3 @@ class Bytes(bytes, PyNativeObject):
         val = bytes.__new__(cls, content)
         val._tvm_ffi_cached_object = obj
         return val
-
-
-_register_object_by_index(kTVMFFIBytes, Bytes)
-_register_object_by_index(kTVMFFIObject, Object)
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 0985056..e6fba34 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -330,7 +330,6 @@ cdef class Tensor(Object):
 
 
 _set_class_tensor(Tensor)
-_register_object_by_index(kTVMFFITensor, Tensor)
 
 
 cdef int _dltensor_test_wrapper_from_pyobject(
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 4701a74..c39b0be 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -14,11 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
+
 import sys
 from typing import Any
 
 import pytest
 import tvm_ffi
+import tvm_ffi.testing
 from tvm_ffi.core import TypeInfo
 
 
@@ -169,3 +172,59 @@ def test_unregistered_object_fallback() -> None:
     for _ in range(5):
         obj = tvm_ffi.testing.make_unregistered_object()
         _check_type(obj)
+
+
+@pytest.mark.parametrize(
+    "test_cls, type_key, parent_cls",
+    [
+        (tvm_ffi.Object, "ffi.Object", None),
+        (tvm_ffi.Tensor, "ffi.Tensor", tvm_ffi.Object),
+        (tvm_ffi.core.DataType, "DataType", None),
+        (tvm_ffi.Device, "Device", None),
+        (tvm_ffi.Shape, "ffi.Shape", tvm_ffi.Object),
+        (tvm_ffi.Module, "ffi.Module", tvm_ffi.Object),
+        (tvm_ffi.Function, "ffi.Function", tvm_ffi.Object),
+        (tvm_ffi.core.Error, "ffi.Error", tvm_ffi.Object),
+        (tvm_ffi.core.String, "ffi.String", tvm_ffi.Object),
+        (tvm_ffi.core.Bytes, "ffi.Bytes", tvm_ffi.Object),
+        (tvm_ffi.Tensor, "ffi.Tensor", tvm_ffi.Object),
+        (tvm_ffi.Array, "ffi.Array", tvm_ffi.Object),
+        (tvm_ffi.Map, "ffi.Map", tvm_ffi.Object),
+        (tvm_ffi.access_path.AccessStep, "ffi.reflection.AccessStep", tvm_ffi.Object),
+        (tvm_ffi.access_path.AccessPath, "ffi.reflection.AccessPath", tvm_ffi.Object),
+        (tvm_ffi.testing.TestIntPair, "testing.TestIntPair", tvm_ffi.Object),
+        (tvm_ffi.testing.TestObjectBase, "testing.TestObjectBase", tvm_ffi.Object),
+        (
+            tvm_ffi.testing.TestObjectDerived,
+            "testing.TestObjectDerived",
+            tvm_ffi.testing.TestObjectBase,
+        ),
+        (tvm_ffi.testing._TestCxxClassBase, "testing.TestCxxClassBase", tvm_ffi.Object),
+        (
+            tvm_ffi.testing._TestCxxClassDerived,
+            "testing.TestCxxClassDerived",
+            tvm_ffi.testing._TestCxxClassBase,
+        ),
+        (
+            tvm_ffi.testing._TestCxxClassDerivedDerived,
+            "testing.TestCxxClassDerivedDerived",
+            tvm_ffi.testing._TestCxxClassDerived,
+        ),
+        (tvm_ffi.testing._TestCxxInitSubset, "testing.TestCxxInitSubset", tvm_ffi.Object),
+        (tvm_ffi.testing._SchemaAllTypes, "testing.SchemaAllTypes", tvm_ffi.Object),
+    ],
+)
+def test_type_info_attachment(test_cls: type, type_key: str, parent_cls: type | None) -> None:
+    type_info = tvm_ffi.core._type_cls_to_type_info(test_cls)
+    assert type_info is not None
+    assert type_info.type_cls is test_cls
+    assert type_info.type_key == type_key, (
+        f"Expected type key `{type_key}` but got `{type_info.type_key}`"
+    )
+    parent_type_info = type_info.parent_type_info
+    if parent_type_info is None:
+        assert parent_cls is None
+    else:
+        assert parent_type_info.type_cls is parent_cls, (
+            f"Expected parent type {parent_cls}, but got {parent_type_info.type_cls}"
+        )

Reply via email to