This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 53b2e00 feat: Introduce Python-side TypeInfo Metadata (#30)
53b2e00 is described below
commit 53b2e00ef90a34f2dfa79014877dc6ca53e78c0f
Author: Junru Shao <[email protected]>
AuthorDate: Fri Sep 19 11:15:35 2025 -0700
feat: Introduce Python-side TypeInfo Metadata (#30)
Previously, `TypeInfo` lived strictly in C and was glued to Python via
Cython. This PR adds a Cython-side registry that stores Python objects,
i.e. `TypeInfo`, which makes it possible to look up type info in pure
Python.
This PR is split from #8.
---
python/tvm_ffi/core.pyi | 39 ++++++++++++-
python/tvm_ffi/cython/core.pyx | 3 +-
python/tvm_ffi/cython/function.pxi | 103 ---------------------------------
python/tvm_ffi/cython/object.pxi | 100 +++++++++++++++++++++++++++-----
python/tvm_ffi/cython/type_info.pxi | 112 ++++++++++++++++++++++++++++++++++++
python/tvm_ffi/registry.py | 45 ++++++++++++++-
6 files changed, 278 insertions(+), 124 deletions(-)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index d57d020..afadcd1 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -74,9 +74,9 @@ class PyNativeObject:
) -> None: ...
def _set_class_object(cls: type) -> None: ...
-def _register_object_by_index(index: int, cls: type) -> None: ...
+def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
def _object_type_key_to_index(type_key: str) -> int | None: ...
-def _add_class_attrs_by_reflection(type_index: int, cls: type) -> type: ...
+def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: ...
class Error(Object):
"""Base class for FFI errors."""
@@ -225,3 +225,38 @@ class Bytes(bytes, PyNativeObject):
# pylint: disable=no-self-argument
def __from_tvm_ffi_object__(cls, obj: Any) -> Bytes: ...
+
+# ---------------------------------------------------------------------------
+# Type reflection metadata (from cython/type_info.pxi)
+# ---------------------------------------------------------------------------
+
+class TypeField:
+ """Description of a single reflected field on an FFI-backed type."""
+
+ name: str
+ doc: str | None
+ size: int
+ offset: int
+ frozen: bool
+ getter: Any
+ setter: Any
+
+ def as_property(self, cls: type) -> property: ...
+
+class TypeMethod:
+ """Description of a single reflected method on an FFI-backed type."""
+
+ name: str
+ doc: str | None
+ func: Any
+ is_static: bool
+
+class TypeInfo:
+ """Aggregated type information required to build a proxy class."""
+
+ type_cls: type | None
+ type_index: int
+ type_key: str
+ fields: list[TypeField]
+ methods: list[TypeMethod]
+ parent_type_info: TypeInfo | None
diff --git a/python/tvm_ffi/cython/core.pyx b/python/tvm_ffi/cython/core.pyx
index b24a83d..ca3a0ce 100644
--- a/python/tvm_ffi/cython/core.pyx
+++ b/python/tvm_ffi/cython/core.pyx
@@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
include "./base.pxi"
+include "./type_info.pxi"
include "./dtype.pxi"
include "./device.pxi"
include "./object.pxi"
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index c4662ca..f4503ec 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -620,109 +620,6 @@ cdef class Function(Object):
_register_object_by_index(kTVMFFIFunction, Function)
-cdef class FieldGetter:
- cdef TVMFFIFieldGetter getter
- cdef int64_t offset
-
- def __call__(self, Object obj):
- cdef TVMFFIAny result
- cdef int c_api_ret_code
- cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
- result.type_index = kTVMFFINone
- result.v_int64 = 0
- c_api_ret_code = self.getter(field_ptr, &result)
- CHECK_CALL(c_api_ret_code)
- return make_ret(result)
-
-
-cdef class FieldSetter:
- cdef TVMFFIFieldSetter setter
- cdef int64_t offset
-
- def __call__(self, Object obj, value):
- cdef int c_api_ret_code
- cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
- TVMFFIPyCallFieldSetter(
- TVMFFIPyArgSetterFactory_,
- self.setter,
- field_ptr,
- <PyObject*>value,
- &c_api_ret_code
- )
- # NOTE: logic is same as check_call
- # directly inline here to simplify traceback
- if c_api_ret_code == 0:
- return
- elif c_api_ret_code == -2:
- raise_existing_error()
- raise move_from_last_error().py_error()
-
-
-cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
- cdef TVMFFIAny result
- CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result))
- return make_ret(result)
-
-
-def _member_method_wrapper(method_func):
- def wrapper(self, *args):
- return method_func(self, *args)
- return wrapper
-
-
-def _add_class_attrs_by_reflection(int type_index, object cls):
- """Decorate the class attrs by reflection"""
- cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(type_index)
- cdef const TVMFFIFieldInfo* field
- cdef const TVMFFIMethodInfo* method
- cdef int num_fields = info.num_fields
- cdef int num_methods = info.num_methods
-
- for i in range(num_fields):
- # attach fields to the class
- field = &(info.fields[i])
- getter = FieldGetter.__new__(FieldGetter)
- (<FieldGetter>getter).getter = field.getter
- (<FieldGetter>getter).offset = field.offset
- setter = FieldSetter.__new__(FieldSetter)
- (<FieldSetter>setter).setter = field.setter
- (<FieldSetter>setter).offset = field.offset
- if (field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0:
- setter = None
- doc = bytearray_to_str(&field.doc) if field.doc.size != 0 else None
- name = bytearray_to_str(&field.name)
- if hasattr(cls, name):
- # skip already defined attributes
- continue
- setattr(cls, name, property(getter, setter, doc=doc))
-
- for i in range(num_methods):
- # attach methods to the class
- method = &(info.methods[i])
- name = bytearray_to_str(&method.name)
- doc = bytearray_to_str(&method.doc) if method.doc.size != 0 else None
- method_func = _get_method_from_method_info(method)
-
- if method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod:
- method_pyfunc = staticmethod(method_func)
- else:
- # must call into another method instead of direct capture
- # to avoid the same method_func variable being used
- # across multiple loop iterations
- method_pyfunc = _member_method_wrapper(method_func)
-
- if doc is not None:
- method_pyfunc.__doc__ = doc
- method_pyfunc.__name__ = name
-
- if hasattr(cls, name):
- # skip already defined attributes
- continue
- setattr(cls, name, method_pyfunc)
-
- return cls
-
-
def _register_global_func(name, pyfunc, override):
cdef TVMFFIObjectHandle chandle
cdef int c_api_ret_code
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 08cfd54..3d0e33e 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -227,18 +227,6 @@ class PyNativeObject:
self.__tvm_ffi_object__ = obj
-"""Maps object type index to its constructor"""
-cdef list OBJECT_TYPE = []
-
-
-def _register_object_by_index(int index, object cls):
- """register object class"""
- global OBJECT_TYPE
- while len(OBJECT_TYPE) <= index:
- OBJECT_TYPE.append(None)
- OBJECT_TYPE[index] = cls
-
-
def _object_type_key_to_index(str type_key):
"""get the type index of object class"""
cdef int32_t tidx
@@ -265,13 +253,13 @@ cdef inline object make_ret_opaque_object(TVMFFIAny result):
cdef inline object make_ret_object(TVMFFIAny result):
- global OBJECT_TYPE
+ global TYPE_INDEX_TO_INFO
cdef int32_t tindex
cdef object cls
tindex = result.type_index
- if tindex < len(OBJECT_TYPE):
- cls = OBJECT_TYPE[tindex]
+ if tindex < len(TYPE_INDEX_TO_INFO):
+ cls = TYPE_INDEX_TO_INFO[tindex].type_cls
if cls is not None:
if issubclass(cls, PyNativeObject):
obj = Object.__new__(Object)
@@ -290,4 +278,86 @@ cdef inline object make_ret_object(TVMFFIAny result):
return obj
+cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
+ cdef TVMFFIAny result
+ CHECK_CALL(TVMFFIAnyViewToOwnedAny(&(method.method), &result))
+ return make_ret(result)
+
+
+def _type_info_create_from_type_key(object type_cls, str type_key):
+ cdef const TVMFFIFieldInfo* field
+ cdef const TVMFFIMethodInfo* method
+ cdef const TVMFFITypeInfo* info
+ cdef int32_t type_index
+ cdef object fields = []
+ cdef object methods = []
+ cdef FieldGetter getter
+ cdef FieldSetter setter
+
+ if TVMFFITypeKeyToIndex(ByteArrayArg(c_str(type_key)).cptr(), &type_index) != 0:
+ raise ValueError(f"Cannot find type key: {type_key}")
+ info = TVMFFIGetTypeInfo(type_index)
+ for i in range(info.num_fields):
+ field = &(info.fields[i])
+ getter = FieldGetter.__new__(FieldGetter)
+ (<FieldGetter>getter).getter = field.getter
+ (<FieldGetter>getter).offset = field.offset
+ setter = FieldSetter.__new__(FieldSetter)
+ (<FieldSetter>setter).setter = field.setter
+ (<FieldSetter>setter).offset = field.offset
+ fields.append(
+ TypeField(
+ name=bytearray_to_str(&field.name),
+ doc=bytearray_to_str(&field.doc) if field.doc.size != 0 else None,
+ size=field.size,
+ offset=field.offset,
+ frozen=(field.flags & kTVMFFIFieldFlagBitMaskWritable) == 0,
+ getter=getter,
+ setter=setter,
+ )
+ )
+
+ for i in range(info.num_methods):
+ method = &(info.methods[i])
+ methods.append(
+ TypeMethod(
+ name=bytearray_to_str(&method.name),
+ doc=bytearray_to_str(&method.doc) if method.doc.size != 0 else None,
+ func=_get_method_from_method_info(method),
+ is_static=(method.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0,
+ )
+ )
+
+ return TypeInfo(
+ type_cls=type_cls,
+ type_index=type_index,
+ type_key=bytearray_to_str(&info.type_key),
+ fields=fields,
+ methods=methods,
+ parent_type_info=None,
+ )
+
+
+def _register_object_by_index(int type_index, object type_cls):
+ global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO
+ cdef str type_key = _type_index_to_key(type_index)
+ cdef object info = _type_info_create_from_type_key(type_cls, type_key)
+ if (extra := type_index + 1 - len(TYPE_INDEX_TO_INFO)) > 0:
+ TYPE_INDEX_TO_INFO.extend([None] * extra)
+ TYPE_INDEX_TO_INFO[type_index] = info
+ TYPE_KEY_TO_INFO[type_key] = info
+ return info
+
+
+def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo:
+ if info := TYPE_KEY_TO_INFO.get(type_key, None):
+ return info
+ info = _type_info_create_from_type_key(None, type_key)
+ TYPE_KEY_TO_INFO[type_key] = info
+ return info
+
+
+cdef list TYPE_INDEX_TO_INFO = []
+cdef dict TYPE_KEY_TO_INFO = {}
+
_set_class_object(Object)
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
new file mode 100644
index 0000000..2abb204
--- /dev/null
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import dataclasses
+
+
+cdef class FieldGetter:
+ cdef dict __dict__
+ cdef TVMFFIFieldGetter getter
+ cdef int64_t offset
+
+ def __call__(self, Object obj):
+ cdef TVMFFIAny result
+ cdef int c_api_ret_code
+ cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
+ result.type_index = kTVMFFINone
+ result.v_int64 = 0
+ c_api_ret_code = self.getter(field_ptr, &result)
+ CHECK_CALL(c_api_ret_code)
+ return make_ret(result)
+
+
+cdef class FieldSetter:
+ cdef dict __dict__
+ cdef TVMFFIFieldSetter setter
+ cdef int64_t offset
+
+ def __call__(self, Object obj, value):
+ cdef int c_api_ret_code
+ cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
+ TVMFFIPyCallFieldSetter(
+ TVMFFIPyArgSetterFactory_,
+ self.setter,
+ field_ptr,
+ <PyObject*>value,
+ &c_api_ret_code
+ )
+ # NOTE: logic is same as check_call
+ # directly inline here to simplify traceback
+ if c_api_ret_code == 0:
+ return
+ elif c_api_ret_code == -2:
+ raise_existing_error()
+ raise move_from_last_error().py_error()
+
+
+@dataclasses.dataclass(eq=False)
+class TypeField:
+ """Description of a single reflected field on an FFI-backed type."""
+
+ name: str
+ doc: str | None
+ size: int
+ offset: int
+ frozen: bool
+ getter: FieldGetter
+ setter: FieldSetter
+
+ def __post_init__(self):
+ assert self.setter is not None
+ assert self.getter is not None
+
+ def as_property(self, cls: type) -> property:
+ """Create a Python ``property`` object for this field on ``cls``."""
+ name = self.name
+ fget = self.getter
+ fset = self.setter
+ fget.__name__ = fset.__name__ = name
+ fget.__module__ = fset.__module__ = cls.__module__
+ fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}" # type: ignore[attr-defined]
+ fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`" # type: ignore[attr-defined]
+
+ return property(
+ fget=fget if self.getter is not None else None,
+ fset=fset if (not self.frozen) and self.setter is not None else None,
+ doc=f"{cls.__module__}.{cls.__qualname__}.{name}",
+ )
+
+
+@dataclasses.dataclass(eq=False)
+class TypeMethod:
+ """Description of a single reflected method on an FFI-backed type."""
+
+ name: str
+ doc: str | None
+ func: object
+ is_static: bool
+
+
+@dataclasses.dataclass(eq=False)
+class TypeInfo:
+ """Aggregated type information required to build a proxy class."""
+
+ type_cls: type | None
+ type_index: int
+ type_key: str
+ fields: list[TypeField]
+ methods: list[TypeMethod]
+ parent_type_info: TypeInfo | None
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 1f4e340..5a540fb 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -20,6 +20,7 @@ import sys
from typing import Any, Callable, Optional
from . import core
+from .core import TypeInfo
# whether we simplify skip unknown objects regtistration
_SKIP_UNKNOWN_OBJECTS = False
@@ -54,8 +55,8 @@ def register_object(type_key: str | type | None = None) -> Any:
if _SKIP_UNKNOWN_OBJECTS:
return cls
raise ValueError(f"Cannot find object type index for {object_name}")
- core._add_class_attrs_by_reflection(type_index, cls)
- core._register_object_by_index(type_index, cls)
+ info = core._register_object_by_index(type_index, cls)
+ _add_class_attrs(type_cls=cls, type_info=info)
return cls
if isinstance(type_key, str):
@@ -228,6 +229,46 @@ def init_ffi_api(namespace: str, target_module_name: Optional[str] = None) -> No
setattr(target_module, f.__name__, f)
+def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]:
+ def wrapper(self: Any, *args: Any) -> Any:
+ return method_func(self, *args)
+
+ return wrapper
+
+
+def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
+ for field in type_info.fields:
+ getter = field.getter
+ setter = field.setter if not field.frozen else None
+ doc = field.doc if field.doc else None
+ name = field.name
+ if hasattr(type_cls, name):
+ # skip already defined attributes
+ continue
+ setattr(type_cls, name, property(getter, setter, doc=doc))
+ for method in type_info.methods:
+ name = method.name
+ doc = method.doc if method.doc else None
+ method_func = method.func
+ if method.is_static:
+ method_pyfunc = staticmethod(method_func)
+ else:
+ # must call into another method instead of direct capture
+ # to avoid the same method_func variable being used
+ # across multiple loop iterations
+ method_pyfunc = _member_method_wrapper(method_func)
+
+ if doc is not None:
+ method_pyfunc.__doc__ = doc
+ method_pyfunc.__name__ = name
+
+ if hasattr(type_cls, name):
+ # skip already defined attributes
+ continue
+ setattr(type_cls, name, method_pyfunc)
+ return type_cls
+
+
__all__ = [
"get_global_func",
"init_ffi_api",