This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 53b2e00  feat: Introduce Python-side TypeInfo Metadata (#30)
53b2e00 is described below

commit 53b2e00ef90a34f2dfa79014877dc6ca53e78c0f
Author: Junru Shao <[email protected]>
AuthorDate: Fri Sep 19 11:15:35 2025 -0700

    feat: Introduce Python-side TypeInfo Metadata (#30)
    
    Previously, `TypeInfo` lived strictly in C and was glued to Python via
    Cython. This PR adds a Cython-side registry that stores Python objects,
    i.e. `TypeInfo`, which makes it possible to look up type info in pure
    Python.
    
    This PR is split from #8.
---
 python/tvm_ffi/core.pyi             |  39 ++++++++++++-
 python/tvm_ffi/cython/core.pyx      |   3 +-
 python/tvm_ffi/cython/function.pxi  | 103 ---------------------------------
 python/tvm_ffi/cython/object.pxi    | 100 +++++++++++++++++++++++++++-----
 python/tvm_ffi/cython/type_info.pxi | 112 ++++++++++++++++++++++++++++++++++++
 python/tvm_ffi/registry.py          |  45 ++++++++++++++-
 6 files changed, 278 insertions(+), 124 deletions(-)

diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index d57d020..afadcd1 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -74,9 +74,9 @@ class PyNativeObject:
     ) -> None: ...
 
 def _set_class_object(cls: type) -> None: ...
-def _register_object_by_index(index: int, cls: type) -> None: ...
+def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...  # stores and returns the TypeInfo
 def _object_type_key_to_index(type_key: str) -> int | None: ...
-def _add_class_attrs_by_reflection(type_index: int, cls: type) -> type: ...
+def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: ...  # raises ValueError for unknown keys
 
 class Error(Object):
     """Base class for FFI errors."""
@@ -225,3 +225,38 @@ class Bytes(bytes, PyNativeObject):
 
     # pylint: disable=no-self-argument
     def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ...
+
+# ---------------------------------------------------------------------------
+# Type reflection metadata (from cython/type_info.pxi)
+# ---------------------------------------------------------------------------
+
+class TypeField:
+    """Description of a single reflected field on an FFI-backed type."""
+
+    name: str
+    doc: str | None
+    size: int
+    offset: int  # byte offset of the field from the object handle
+    frozen: bool  # True when the field is not writable
+    getter: Any  # callable that reads the field from an Object
+    setter: Any  # callable that writes the field; unused for frozen fields
+
+    def as_property(self, cls: type) -> property: ...
+
+class TypeMethod:
+    """Description of a single reflected method on an FFI-backed type."""
+
+    name: str
+    doc: str | None
+    func: Any  # the reflected function object
+    is_static: bool  # attached as a staticmethod when True
+
+class TypeInfo:
+    """Aggregated type information required to build a proxy class."""
+
+    type_cls: type | None  # None for types only looked up, never registered
+    type_index: int
+    type_key: str
+    fields: list[TypeField]
+    methods: list[TypeMethod]
+    parent_type_info: TypeInfo | None  # currently always None from the Cython side
diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx
index b24a83d..ca3a0ce 100644
--- a/python/tvm_ffi/cython/core.pyx
+++ b/python/tvm_ffi/cython/core.pyx
@@ -14,9 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-
 include "./base.pxi"
+include "./type_info.pxi"  # FieldGetter/FieldSetter and TypeField/TypeMethod/TypeInfo
 include "./dtype.pxi"
 include "./device.pxi"
 include "./object.pxi"
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index c4662ca..f4503ec 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -620,109 +620,6 @@ cdef class Function(Object):
 _register_object_by_index(kTVMFFIFunction, Function)
 
 
-cdef class FieldGetter:
-    cdef TVMFFIFieldGetter getter
-    cdef int64_t offset
-
-    def __call__(self, Object obj):
-        cdef TVMFFIAny result
-        cdef int c_api_ret_code
-        cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
-        result.type_index = kTVMFFINone
-        result.v_int64 = 0
-        c_api_ret_code = self.getter(field_ptr, &result)
-        CHECK_CALL(c_api_ret_code)
-        return make_ret(result)
-
-
-cdef class FieldSetter:
-    cdef TVMFFIFieldSetter setter
-    cdef int64_t offset
-
-    def __call__(self, Object obj, value):
-        cdef int c_api_ret_code
-        cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
-        TVMFFIPyCallFieldSetter(
-            TVMFFIPyArgSetterFactory_,
-            self.setter,
-            field_ptr,
-            <PyObject*>value,
-            &c_api_ret_code
-        )
-        # NOTE: logic is same as check_call
-        # directly inline here to simplify traceback
-        if c_api_ret_code == 0:
-            return
-        elif c_api_ret_code == -2:
-            raise_existing_error()
-        raise move_from_last_error().py_error()
-
-
-cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
-    cdef TVMFFIAny result
-    CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result))
-    return make_ret(result)
-
-
-def _member_method_wrapper(method_func):
-    def wrapper(self, *args):
-        return method_func(self, *args)
-    return wrapper
-
-
-def _add_class_attrs_by_reflection(int type_index, object cls):
-    """Decorate the class attrs by reflection"""
-    cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index)
-    cdef const TVMFFIFieldInfo* field
-    cdef const TVMFFIMethodInfo* method
-    cdef int num_fields = info.num_fields
-    cdef int num_methods = info.num_methods
-
-    for i in range(num_fields):
-        # attach fields to the class
-        field = &(info.fields[i])
-        getter = FieldGetter.__new__(FieldGetter)
-        (<FieldGetter>getter).getter = field.getter
-        (<FieldGetter>getter).offset = field.offset
-        setter = FieldSetter.__new__(FieldSetter)
-        (<FieldSetter>setter).setter = field.setter
-        (<FieldSetter>setter).offset = field.offset
-        if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0:
-            setter = None
-        doc = bytearray_to_str(&field.doc) if field.doc.size != 0 else None
-        name = bytearray_to_str(&field.name)
-        if hasattr(cls, name):
-            # skip already defined attributes
-            continue
-        setattr(cls, name, property(getter, setter, doc=doc))
-
-    for i in range(num_methods):
-        # attach methods to the class
-        method = &(info.methods[i])
-        name = bytearray_to_str(&method.name)
-        doc = bytearray_to_str(&method.doc) if method.doc.size != 0 else None
-        method_func = _get_method_from_method_info(method)
-
-        if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod:
-            method_pyfunc = staticmethod(method_func)
-        else:
-            # must call into another method instead of direct capture
-            # to avoid the same method_func variable being used
-            # across multiple loop iterations
-            method_pyfunc = _member_method_wrapper(method_func)
-
-        if doc is not None:
-            method_pyfunc.__doc__ = doc
-        method_pyfunc.__name__ = name
-
-        if hasattr(cls, name):
-            # skip already defined attributes
-            continue
-        setattr(cls, name, method_pyfunc)
-
-    return cls
-
-
 def _register_global_func(name, pyfunc, override):
     cdef TVMFFIObjectHandle chandle
     cdef int c_api_ret_code
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 08cfd54..3d0e33e 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -227,18 +227,6 @@ class PyNativeObject:
         self.__tvm_ffi_object__ = obj
 
 
-"""Maps object type index to its constructor"""
-cdef list OBJECT_TYPE = []
-
-
-def _register_object_by_index(int index, object cls):
-    """register object class"""
-    global OBJECT_TYPE
-    while len(OBJECT_TYPE) <= index:
-        OBJECT_TYPE.append(None)
-    OBJECT_TYPE[index] = cls
-
-
 def _object_type_key_to_index(str type_key):
     """get the type index of object class"""
     cdef int32_t tidx
@@ -265,13 +253,13 @@ cdef inline object make_ret_opaque_object(TVMFFIAny result):


 cdef inline object make_ret_object(TVMFFIAny result):
-    global OBJECT_TYPE
+    global TYPE_INDEX_TO_INFO
     cdef int32_t tindex
     cdef object cls
     tindex = result.type_index

-    if tindex < len(OBJECT_TYPE):
-        cls = OBJECT_TYPE[tindex]
+    if tindex < len(TYPE_INDEX_TO_INFO):
+        cls = getattr(TYPE_INDEX_TO_INFO[tindex], "type_cls", None)  # unregistered slots hold None
         if cls is not None:
             if issubclass(cls, PyNativeObject):
                 obj = Object.__new__(Object)
@@ -290,4 +278,105 @@ cdef inline object make_ret_object(TVMFFIAny result):
     return obj


+cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
+    """Convert the method stored in ``method`` into an owned Python value."""
+    cdef TVMFFIAny result
+    CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result))
+    return make_ret(result)
+
+
+def _type_info_create_from_type_key(object type_cls, str type_key):
+    """Build a :class:`TypeInfo` for ``type_key`` from the C reflection table.
+
+    ``type_cls`` is stored as-is and may be ``None``. Raises ``ValueError``
+    when ``TVMFFITypeKeyToIndex`` cannot resolve ``type_key``.
+    """
+    cdef const TVMFFIFieldInfo* field
+    cdef const TVMFFIMethodInfo* method
+    cdef const TVMFFITypeInfo* info
+    cdef int32_t type_index
+    cdef list fields = []
+    cdef list methods = []
+    cdef FieldGetter getter
+    cdef FieldSetter setter
+
+    if TVMFFITypeKeyToIndex(ByteArrayArg(c_str(type_key)).cptr(), &type_index) != 0:
+        raise ValueError(f"Cannot find type key: {type_key}")
+    info = TVMFFIGetTypeInfo(type_index)
+    for i in range(info.num_fields):
+        field = &(info.fields[i])
+        getter = FieldGetter.__new__(FieldGetter)
+        getter.getter = field.getter
+        getter.offset = field.offset
+        setter = FieldSetter.__new__(FieldSetter)
+        setter.setter = field.setter
+        setter.offset = field.offset
+        fields.append(
+            TypeField(
+                name=bytearray_to_str(&field.name),
+                doc=bytearray_to_str(&field.doc) if field.doc.size != 0 else None,
+                size=field.size,
+                offset=field.offset,
+                frozen=(field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0,
+                getter=getter,
+                setter=setter,
+            )
+        )
+
+    for i in range(info.num_methods):
+        method = &(info.methods[i])
+        methods.append(
+            TypeMethod(
+                name=bytearray_to_str(&method.name),
+                doc=bytearray_to_str(&method.doc) if method.doc.size != 0 else None,
+                func=_get_method_from_method_info(method),
+                is_static=(method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0,
+            )
+        )
+
+    return TypeInfo(
+        type_cls=type_cls,
+        type_index=type_index,
+        type_key=bytearray_to_str(&info.type_key),
+        fields=fields,
+        methods=methods,
+        parent_type_info=None,
+    )
+
+
+def _register_object_by_index(int type_index, object type_cls):
+    """Register ``type_cls`` as the Python class for ``type_index``.
+
+    Builds the :class:`TypeInfo`, stores it in both lookup tables, and
+    returns it.
+    """
+    global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO
+    cdef str type_key = _type_index_to_key(type_index)
+    cdef object info = _type_info_create_from_type_key(type_cls, type_key)
+    # Grow the index table, padding never-registered slots with None.
+    if (extra := type_index + 1 - len(TYPE_INDEX_TO_INFO)) > 0:
+        TYPE_INDEX_TO_INFO.extend([None] * extra)
+    TYPE_INDEX_TO_INFO[type_index] = info
+    TYPE_KEY_TO_INFO[type_key] = info
+    return info
+
+
+def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo:
+    """Return the cached :class:`TypeInfo` for ``type_key``, creating it if needed.
+
+    Types looked up here without prior registration get ``type_cls=None`` and
+    are cached by key only.
+    """
+    if (info := TYPE_KEY_TO_INFO.get(type_key)) is not None:
+        return info
+    info = _type_info_create_from_type_key(None, type_key)
+    TYPE_KEY_TO_INFO[type_key] = info
+    return info
+
+
+# Type index -> TypeInfo; indices that were never registered hold None.
+cdef list TYPE_INDEX_TO_INFO = []
+# Type key -> TypeInfo, for both registered and merely looked-up types.
+cdef dict TYPE_KEY_TO_INFO = {}
+
 _set_class_object(Object)
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
new file mode 100644
index 0000000..2abb204
--- /dev/null
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import dataclasses
+
+
+cdef class FieldGetter:
+    """Callable that reads one reflected field of an :class:`Object`.
+
+    Invokes the C ``getter`` at ``chandle + offset`` and converts the result
+    with ``make_ret``. ``__dict__`` lets ``as_property`` set naming metadata.
+    """
+    cdef dict __dict__
+    cdef TVMFFIFieldGetter getter
+    cdef int64_t offset
+
+    def __call__(self, Object obj):
+        cdef TVMFFIAny result
+        cdef int c_api_ret_code
+        cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
+        result.type_index = kTVMFFINone
+        result.v_int64 = 0
+        c_api_ret_code = self.getter(field_ptr, &result)
+        CHECK_CALL(c_api_ret_code)
+        return make_ret(result)
+
+
+cdef class FieldSetter:
+    """Callable that writes one reflected field of an :class:`Object`.
+
+    Forwards ``value`` and the field address to ``TVMFFIPyCallFieldSetter``.
+    """
+    cdef dict __dict__
+    cdef TVMFFIFieldSetter setter
+    cdef int64_t offset
+
+    def __call__(self, Object obj, value):
+        cdef int c_api_ret_code
+        cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
+        TVMFFIPyCallFieldSetter(
+            TVMFFIPyArgSetterFactory_,
+            self.setter,
+            field_ptr,
+            <PyObject*>value,
+            &c_api_ret_code
+        )
+        # NOTE: logic is same as check_call
+        # directly inline here to simplify traceback
+        if c_api_ret_code == 0:
+            return
+        elif c_api_ret_code == -2:
+            raise_existing_error()
+        raise move_from_last_error().py_error()
+
+
+@dataclasses.dataclass(eq=False)
+class TypeField:
+    """Description of a single reflected field on an FFI-backed type."""
+
+    name: str
+    doc: str | None
+    size: int
+    offset: int
+    frozen: bool
+    getter: FieldGetter
+    setter: FieldSetter
+
+    def __post_init__(self):
+        # Raise rather than assert so the check also holds under ``python -O``.
+        if self.getter is None or self.setter is None:
+            raise ValueError(f"TypeField {self.name!r} requires both a getter and a setter")
+
+    def as_property(self, cls: type) -> property:
+        """Create a Python ``property`` object for this field on ``cls``.
+
+        The reflected field doc is preferred when present; otherwise generated
+        descriptions are used. Frozen fields yield a read-only property.
+        """
+        name = self.name
+        fget = self.getter
+        fset = self.setter
+        fget.__name__ = fset.__name__ = name
+        fget.__module__ = fset.__module__ = cls.__module__
+        fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"  # type: ignore[attr-defined]
+        fget.__doc__ = fset.__doc__ = self.doc or f"Property `{name}` of class `{cls.__qualname__}`"  # type: ignore[attr-defined]
+        return property(
+            fget=fget,
+            fset=None if self.frozen else fset,
+            doc=self.doc or f"{cls.__module__}.{cls.__qualname__}.{name}",
+        )
+
+
+@dataclasses.dataclass(eq=False)
+class TypeMethod:
+    """Description of a single reflected method on an FFI-backed type."""
+
+    name: str
+    doc: str | None
+    func: object
+    is_static: bool
+
+
+@dataclasses.dataclass(eq=False)
+class TypeInfo:
+    """Aggregated type information required to build a proxy class."""
+
+    type_cls: type | None
+    type_index: int
+    type_key: str
+    fields: list[TypeField]
+    methods: list[TypeMethod]
+    parent_type_info: TypeInfo | None
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 1f4e340..5a540fb 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -20,6 +20,7 @@ import sys
 from typing import Any, Callable, Optional
 
 from . import core
+from .core import TypeInfo
 
 # whether we simplify skip unknown objects regtistration
 _SKIP_UNKNOWN_OBJECTS = False
@@ -54,8 +55,8 @@ def register_object(type_key: str | type | None = None) -> Any:
             if _SKIP_UNKNOWN_OBJECTS:
                 return cls
             raise ValueError(f"Cannot find object type index for {object_name}")
-        core._add_class_attrs_by_reflection(type_index, cls)
-        core._register_object_by_index(type_index, cls)
+        info = core._register_object_by_index(type_index, cls)
+        _add_class_attrs(type_cls=cls, type_info=info)
         return cls
 
     if isinstance(type_key, str):
@@ -228,6 +229,53 @@ def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> No
         setattr(target_module, f.__name__, f)


+def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]:
+    """Wrap ``method_func`` so it can be bound as an instance method.
+
+    A fresh closure per call pins ``method_func``; capturing a loop variable
+    directly in ``_add_class_attrs`` would late-bind to the last method.
+    """
+    def wrapper(self: Any, *args: Any) -> Any:
+        return method_func(self, *args)
+
+    return wrapper
+
+
+def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
+    """Attach reflected fields and methods from ``type_info`` to ``type_cls``.
+
+    Fields become properties (read-only when frozen); methods become static
+    methods or instance-method wrappers. Names that ``type_cls`` already
+    resolves (including inherited attributes) are left untouched.
+
+    Returns ``type_cls``.
+    """
+    for field in type_info.fields:
+        name = field.name
+        if hasattr(type_cls, name):
+            # skip already defined attributes
+            continue
+        setter = field.setter if not field.frozen else None
+        doc = field.doc if field.doc else None
+        setattr(type_cls, name, property(field.getter, setter, doc=doc))
+    for method in type_info.methods:
+        name = method.name
+        if hasattr(type_cls, name):
+            # skip already defined attributes
+            continue
+        if method.is_static:
+            method_pyfunc = staticmethod(method.func)
+        else:
+            # Use a helper instead of a closure over the loop variable, which
+            # would late-bind every wrapper to the last method.
+            method_pyfunc = _member_method_wrapper(method.func)
+        if method.doc:
+            method_pyfunc.__doc__ = method.doc
+        method_pyfunc.__name__ = name
+        setattr(type_cls, name, method_pyfunc)
+    return type_cls
+
+
 __all__ = [
     "get_global_func",
     "init_ffi_api",

Reply via email to