This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 98cb8af feat: Auto-create Python classes when missing (#49)
98cb8af is described below
commit 98cb8af49ff599c217fce96c3d4f57c0f52b8ec4
Author: Junru Shao <[email protected]>
AuthorDate: Thu Sep 25 12:11:04 2025 -0700
feat: Auto-create Python classes when missing (#49)
---
include/tvm/ffi/c_api.h | 4 +-
include/tvm/ffi/object.h | 2 +-
include/tvm/ffi/reflection/accessor.h | 4 +-
python/tvm_ffi/core.pyi | 10 ++-
python/tvm_ffi/cython/base.pxi | 2 +-
python/tvm_ffi/cython/object.pxi | 121 ++++++++++++++++++++++++----------
python/tvm_ffi/cython/string.pxi | 1 +
python/tvm_ffi/cython/type_info.pxi | 54 ++++++++++++---
python/tvm_ffi/dataclasses/_utils.py | 12 ----
python/tvm_ffi/dataclasses/c_class.py | 15 ++---
python/tvm_ffi/registry.py | 36 ++--------
src/ffi/extra/reflection_extra.cc | 2 +-
src/ffi/extra/testing.cc | 38 +++++++++--
src/ffi/object.cc | 16 ++---
tests/cpp/test_object.cc | 4 +-
tests/python/test_object.py | 27 +++++---
16 files changed, 218 insertions(+), 130 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index e988e24..62d9001 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -917,11 +917,11 @@ typedef struct TVMFFITypeInfo {
/*! \brief the unique type key to identify the type. */
TVMFFIByteArray type_key;
/*!
- * \brief type_acenstors[depth] stores the type_index of the acenstors at
depth level
+ * \brief type_ancestors[depth] stores the type_index of the acenstors at
depth level
* \note To keep things simple, we do not allow multiple inheritance so the
* hieracy stays as a tree
*/
- const struct TVMFFITypeInfo** type_acenstors;
+ const struct TVMFFITypeInfo** type_ancestors;
// The following fields are used for reflection
/*! \brief Cached hash value of the type key, used for consistent structural
hashing. */
uint64_t type_key_hash;
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 1ebd3d7..93ef6d8 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -1012,7 +1012,7 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t
object_type_index) {
// the function checks that the info exists
const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
return (type_info->type_depth > TargetType::_type_depth &&
- type_info->type_acenstors[TargetType::_type_depth]->type_index
== target_type_index);
+ type_info->type_ancestors[TargetType::_type_depth]->type_index
== target_type_index);
} else {
return false;
}
diff --git a/include/tvm/ffi/reflection/accessor.h
b/include/tvm/ffi/reflection/accessor.h
index 5fadd09..b49da51 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -216,7 +216,7 @@ inline void ForEachFieldInfo(const TypeInfo* type_info,
Callback callback) {
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
- const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
+ const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i];
for (int j = 0; j < parent_info->num_fields; ++j) {
callback(parent_info->fields + j);
}
@@ -243,7 +243,7 @@ inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo*
type_info,
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
- const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
+ const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i];
for (int j = 0; j < parent_info->num_fields; ++j) {
if (callback_with_early_stop(parent_info->fields + j)) return true;
}
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 53c54dd..ac04114 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -46,7 +46,7 @@ class Object:
def __hash__(self) -> int: ...
def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) ->
None: ...
def __ffi_init__(self, *args: Any) -> None:
- """Initialize the instance using the ` __init__` method registered on
C++ side.
+ """Initialize the instance using the ` __ffi_init__` method registered
on C++ side.
Parameters
----------
@@ -83,8 +83,8 @@ class PyNativeObject:
def _set_class_object(cls: type) -> None: ...
def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
def _object_type_key_to_index(type_key: str) -> int | None: ...
-def _set_type_cls(type_index: int, type_cls: type) -> None: ...
-def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: ...
+def _set_type_cls(type_info: TypeInfo, type_cls: type) -> None: ...
+def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: ...
class Error(Object):
"""Base class for FFI errors."""
@@ -267,6 +267,8 @@ class TypeMethod:
func: Any
is_static: bool
+ def as_callable(self, cls: type) -> Callable[..., Any]: ...
+
class TypeInfo:
"""Aggregated type information required to build a proxy class."""
@@ -276,3 +278,5 @@ class TypeInfo:
fields: list[TypeField]
methods: list[TypeMethod]
parent_type_info: TypeInfo | None
+
+ def prototype_py(self) -> str: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 3ff3ecc..5c1ba1e 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -191,7 +191,7 @@ cdef extern from "tvm/ffi/c_api.h":
int32_t type_index
int32_t type_depth
TVMFFIByteArray type_key
- const int32_t* type_acenstors
+ const TVMFFITypeInfo** type_ancestors
uint64_t type_key_hash
int32_t num_fields
int32_t num_methods
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index e8f3593..0bb1e03 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import warnings
+from typing import Any
+
_CLASS_OBJECT = None
@@ -261,29 +263,66 @@ cdef inline object make_ret_opaque_object(TVMFFIAny
result):
(<Object>obj).chandle = result.v_obj
return obj.pyobject()
+cdef inline object make_fallback_cls_for_type_index(int32_t type_index):
+ cdef str type_key = _type_index_to_key(type_index)
+ cdef object type_info =
_lookup_or_register_type_info_from_type_key(type_key)
+ cdef object parent_type_info = type_info.parent_type_info
+ assert type_info.type_cls is None
+
+ # Ensure parent classes are created first
+ assert parent_type_info is not None
+ if parent_type_info.type_cls is None: # recursively create parent class
first
+ make_fallback_cls_for_type_index(parent_type_info.type_index)
+ assert parent_type_info.type_cls is not None
+
+ # Create `type_info.type_cls` now
+ class cls(parent_type_info.type_cls):
+ pass
+ attrs = dict(cls.__dict__)
+ attrs.pop("__dict__", None)
+ attrs.pop("__weakref__", None)
+ attrs.update({
+ "__slots__": (),
+ "__tvm_ffi_type_info__": type_info,
+ "__name__": type_key.split(".")[-1],
+ "__qualname__": type_key,
+ "__module__": ".".join(type_key.split(".")[:-1]),
+ "__doc__": f"Auto-generated fallback class for {type_key}.\n"
+ "This class is generated because the class is not
registered.\n"
+ "Please do not use this class directly, instead register
the class\n"
+ "using `register_object` decorator.",
+ })
+ for field in type_info.fields:
+ attrs[field.name] = field.as_property(cls)
+ for method in type_info.methods:
+ name = method.name
+ if name == "__ffi_init__":
+ name = "__c_ffi_init__"
+ attrs[name] = method.as_callable(cls)
+ for name, val in attrs.items():
+ setattr(cls, name, val)
+ # Update the registry
+ type_info.type_cls = cls
+ _update_registry(type_index, type_key, type_info, cls)
+ return cls
+
cdef inline object make_ret_object(TVMFFIAny result):
- global TYPE_INDEX_TO_INFO
- cdef int32_t tindex
- cdef object cls
- tindex = result.type_index
-
- if tindex < len(TYPE_INDEX_TO_CLS):
- cls = TYPE_INDEX_TO_CLS[tindex]
- if cls is not None:
- if issubclass(cls, PyNativeObject):
- obj = Object.__new__(Object)
- (<Object>obj).chandle = result.v_obj
- return cls.__from_tvm_ffi_object__(cls, obj)
- obj = cls.__new__(cls)
- (<Object>obj).chandle = result.v_obj
- return obj
+ cdef int32_t type_index
+ cdef object cls, obj
+ type_index = result.type_index
- # object is not found in registered entry
- # in this case we need to report an warning
- type_key = _type_index_to_key(tindex)
- warnings.warn(f"Returning type `{type_key}` which is not registered via
register_object, fallback to Object")
- obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
+ if type_index < len(TYPE_INDEX_TO_CLS) and (cls :=
TYPE_INDEX_TO_CLS[type_index]) is not None:
+ if issubclass(cls, PyNativeObject):
+ obj = Object.__new__(Object)
+ (<Object>obj).chandle = result.v_obj
+ return cls.__from_tvm_ffi_object__(cls, obj)
+ else:
+ # Slow path: object is not found in registered entry
+ # In this case create a dummy stub class for future usage.
+ # For every unregistered class, this slow path will be triggered only
once.
+ cls = make_fallback_cls_for_type_index(type_index)
+ obj = cls.__new__(cls)
(<Object>obj).chandle = result.v_obj
return obj
@@ -294,17 +333,21 @@ cdef _get_method_from_method_info(const TVMFFIMethodInfo*
method):
return make_ret(result)
-def _type_info_create_from_type_key(object type_cls, str type_key):
+cdef _type_info_create_from_type_key(object type_cls, str type_key):
cdef const TVMFFIFieldInfo* field
cdef const TVMFFIMethodInfo* method
cdef const TVMFFITypeInfo* info
cdef int32_t type_index
+ cdef list ancestors = []
+ cdef int ancestor
cdef object fields = []
cdef object methods = []
cdef FieldGetter getter
cdef FieldSetter setter
cdef ByteArrayArg type_key_arg = ByteArrayArg(c_str(type_key))
+ # NOTE: `type_key_arg` must be kept alive until after the call to
`TVMFFITypeKeyToIndex`,
+ # because Cython doesn't defer the destruction of `type_key_arg` until
after the call.
if TVMFFITypeKeyToIndex(type_key_arg.cptr(), &type_index) != 0:
raise ValueError(f"Cannot find type key: {type_key}")
info = TVMFFIGetTypeInfo(type_index)
@@ -339,44 +382,54 @@ def _type_info_create_from_type_key(object type_cls, str
type_key):
)
)
+ for i in range(info.type_depth):
+ ancestor = info.type_ancestors[i].type_index
+ ancestors.append(ancestor)
+
return TypeInfo(
type_cls=type_cls,
type_index=type_index,
type_key=bytearray_to_str(&info.type_key),
+ type_ancestors=ancestors,
fields=fields,
methods=methods,
parent_type_info=None,
)
-def _register_object_by_index(int type_index, object type_cls):
- global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO, TYPE_INDEX_TO_CLS
- cdef str type_key = _type_index_to_key(type_index)
- cdef object info = _type_info_create_from_type_key(type_cls, type_key)
+cdef _update_registry(int type_index, object type_key, object type_info,
object type_cls):
+ cdef int extra = type_index + 1 - len(TYPE_INDEX_TO_INFO)
assert len(TYPE_INDEX_TO_INFO) == len(TYPE_INDEX_TO_CLS)
- if (extra := type_index + 1 - len(TYPE_INDEX_TO_INFO)) > 0:
+ if extra > 0:
TYPE_INDEX_TO_INFO.extend([None] * extra)
TYPE_INDEX_TO_CLS.extend([None] * extra)
TYPE_INDEX_TO_CLS[type_index] = type_cls
- TYPE_INDEX_TO_INFO[type_index] = info
- TYPE_KEY_TO_INFO[type_key] = info
+ TYPE_INDEX_TO_INFO[type_index] = type_info
+ TYPE_KEY_TO_INFO[type_key] = type_info
+
+
+def _register_object_by_index(int type_index, object type_cls):
+ global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO, TYPE_INDEX_TO_CLS
+ cdef str type_key = _type_index_to_key(type_index)
+ cdef object info = _type_info_create_from_type_key(type_cls, type_key)
+ _update_registry(type_index, type_key, info, type_cls)
return info
-def _set_type_cls(int type_index, object type_cls):
+def _set_type_cls(object type_info, object type_cls):
global TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS
- assert len(TYPE_INDEX_TO_INFO) == len(TYPE_INDEX_TO_CLS)
- type_info = TYPE_INDEX_TO_INFO[type_index]
assert type_info.type_cls is None, f"Type already registered for
{type_info.type_key}"
+ assert TYPE_INDEX_TO_INFO[type_info.type_index] is type_info
+ assert TYPE_KEY_TO_INFO[type_info.type_key] is type_info
type_info.type_cls = type_cls
- TYPE_INDEX_TO_CLS[type_index] = type_cls
+ TYPE_INDEX_TO_CLS[type_info.type_index] = type_cls
-def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo:
+def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo:
if info := TYPE_KEY_TO_INFO.get(type_key, None):
return info
info = _type_info_create_from_type_key(None, type_key)
- TYPE_KEY_TO_INFO[type_key] = info
+ _update_registry(info.type_index, type_key, info, None)
return info
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index 0f9d11b..bb8f6e5 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -77,3 +77,4 @@ class Bytes(bytes, PyNativeObject):
_register_object_by_index(kTVMFFIBytes, Bytes)
+_register_object_by_index(kTVMFFIObject, Object)
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index fcd443b..ca893b4 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -16,6 +16,7 @@
# under the License.
import dataclasses
from typing import Optional, Any
+from io import StringIO
cdef class FieldGetter:
@@ -75,20 +76,20 @@ class TypeField:
assert self.setter is not None
assert self.getter is not None
- def as_property(self, cls: type):
+ def as_property(self, object cls):
"""Create a Python ``property`` object for this field on ``cls``."""
- name = self.name
- fget = self.getter
- fset = self.setter
+ cdef str name = self.name
+ cdef str doc = self.doc or
f"{cls.__module__}.{cls.__qualname__}.{name}"
+ cdef FieldGetter fget = self.getter
+ cdef FieldSetter fset = self.setter
fget.__name__ = fset.__name__ = name
fget.__module__ = fset.__module__ = cls.__module__
- fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"
# type: ignore[attr-defined]
- fget.__doc__ = fset.__doc__ = f"Property `{name}` of class
`{cls.__qualname__}`" # type: ignore[attr-defined]
-
+ fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"
+ fget.__doc__ = fset.__doc__ = f"Property `{name}` of class
`{cls.__qualname__}`"
return property(
- fget=fget if self.getter is not None else None,
- fset=fset if (not self.frozen) and self.setter is not None else
None,
- doc=f"{cls.__module__}.{cls.__qualname__}.{name}",
+ fget=fget,
+ fset=fset if (not self.frozen) else None,
+ doc=doc,
)
@@ -101,6 +102,21 @@ class TypeMethod:
func: object
is_static: bool
+ def as_callable(self, object cls):
+ """Create a Python method attribute for this method on ``cls``."""
+ cdef str name = self.name
+ cdef str doc = self.doc or f"Method `{name}` of class
`{cls.__qualname__}`"
+ cdef object func = self.func
+ if not self.is_static:
+ func = _member_method_wrapper(func)
+ func.__module__ = cls.__module__
+ func.__name__ = name
+ func.__qualname__ = f"{cls.__qualname__}.{name}"
+ func.__doc__ = doc
+ if self.is_static:
+ func = staticmethod(func)
+ return func
+
@dataclasses.dataclass(eq=False)
class TypeInfo:
@@ -109,6 +125,24 @@ class TypeInfo:
type_cls: Optional[type]
type_index: int
type_key: str
+ type_ancestors: list[int]
fields: list[TypeField]
methods: list[TypeMethod]
parent_type_info: Optional[TypeInfo]
+
+ def __post_init__(self):
+ cdef int parent_type_index
+ cdef str parent_type_key
+ if not self.type_ancestors:
+ return
+ parent_type_index = self.type_ancestors[-1]
+ parent_type_key = _type_index_to_key(parent_type_index)
+ # ensure parent is registered
+ self.parent_type_info =
_lookup_or_register_type_info_from_type_key(parent_type_key)
+
+
+def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[...,
Any]:
+ def wrapper(self: Any, *args: Any) -> Any:
+ return method_func(self, *args)
+
+ return wrapper
diff --git a/python/tvm_ffi/dataclasses/_utils.py
b/python/tvm_ffi/dataclasses/_utils.py
index b6bcdac..60f31fb 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -26,23 +26,11 @@ from ..core import (
Object,
TypeField,
TypeInfo,
- _lookup_type_info_from_type_key,
)
_InputClsType = TypeVar("_InputClsType")
-def get_parent_type_info(type_cls: type) -> TypeInfo:
- """Find the nearest ancestor with registered ``__tvm_ffi_type_info__``.
-
- If none are found, return the base ``ffi.Object`` type info.
- """
- for base in type_cls.__bases__:
- if (info := getattr(base, "__tvm_ffi_type_info__", None)) is not None:
- return info
- return _lookup_type_info_from_type_key("ffi.Object")
-
-
def type_info_to_cls(
type_info: TypeInfo,
cls: type[_InputClsType],
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index dc6aeed..42dc4fd 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -27,16 +27,16 @@ from collections.abc import Callable
from dataclasses import InitVar
from typing import ClassVar, TypeVar, get_origin, get_type_hints
-from typing_extensions import dataclass_transform # type: ignore[attr-defined]
+from typing_extensions import dataclass_transform
-from ..core import TypeField, TypeInfo
+from ..core import TypeField, TypeInfo,
_lookup_or_register_type_info_from_type_key, _set_type_cls
from . import _utils
-from .field import Field, field
+from .field import field
_InputClsType = TypeVar("_InputClsType")
-@dataclass_transform(field_specifiers=(field, Field))
+@dataclass_transform(field_specifiers=(field,))
def c_class(
type_key: str, init: bool = True
) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
@@ -116,9 +116,8 @@ def c_class(
nonlocal init
init = init and "__init__" not in super_type_cls.__dict__
# Step 1. Retrieve `type_info` from registry
- type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key)
- assert type_info.parent_type_info is None, f"Already registered type:
{type_key}"
- type_info.parent_type_info =
_utils.get_parent_type_info(super_type_cls)
+ type_info: TypeInfo =
_lookup_or_register_type_info_from_type_key(type_key)
+ assert type_info.parent_type_info is not None
# Step 2. Reflect all the fields of the type
type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
for type_field in type_info.fields:
@@ -130,7 +129,7 @@ def c_class(
cls=super_type_cls,
methods={"__init__": fn_init},
)
- type_info.type_cls = type_cls
+ _set_type_cls(type_info, type_cls)
return type_cls
return decorator
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 3ef4039..45c1da0 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -249,45 +249,17 @@ def init_ffi_api(namespace: str, target_module_name: str
| None = None) -> None:
setattr(target_module, fname, f)
-def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[...,
Any]:
- def wrapper(self: Any, *args: Any) -> Any:
- return method_func(self, *args)
-
- return wrapper
-
-
def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
for field in type_info.fields:
- getter = field.getter
- setter = field.setter if not field.frozen else None
- doc = field.doc if field.doc else None
name = field.name
- if hasattr(type_cls, name):
- # skip already defined attributes
- continue
- setattr(type_cls, name, property(getter, setter, doc=doc))
+ if not hasattr(type_cls, name): # skip already defined attributes
+ setattr(type_cls, name, field.as_property(type_cls))
for method in type_info.methods:
name = method.name
if name == "__ffi_init__":
name = "__c_ffi_init__"
- doc = method.doc if method.doc else None
- method_func = method.func
- if method.is_static:
- if doc is not None:
- method_func.__doc__ = doc
- method_func.__name__ = name
- method_pyfunc: Any = staticmethod(method_func)
- else:
- wrapped_func = _member_method_wrapper(method_func)
- if doc is not None:
- wrapped_func.__doc__ = doc
- wrapped_func.__name__ = name
- method_pyfunc = wrapped_func
-
- if hasattr(type_cls, name):
- # skip already defined attributes
- continue
- setattr(type_cls, name, method_pyfunc)
+ if not hasattr(type_cls, name):
+ setattr(type_cls, name, method.as_callable(type_cls))
return type_cls
diff --git a/src/ffi/extra/reflection_extra.cc
b/src/ffi/extra/reflection_extra.cc
index f923643..d36e1ac 100644
--- a/src/ffi/extra/reflection_extra.cc
+++ b/src/ffi/extra/reflection_extra.cc
@@ -90,7 +90,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret)
{
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
- update_fields(type_info->type_acenstors[i]);
+ update_fields(type_info->type_ancestors[i]);
}
update_fields(type_info);
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index cf55161..f67752f 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -134,13 +134,23 @@ class TestCxxInitSubsetObj : public Object {
TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset",
TestCxxInitSubsetObj, Object);
};
-class TestUnregisteredObject : public Object {
+class TestUnregisteredBaseObject : public Object {
public:
- int64_t value;
-
- explicit TestUnregisteredObject(int64_t value) : value(value) {}
+ int64_t v1;
+ explicit TestUnregisteredBaseObject(int64_t v1) : v1(v1) {}
+ int64_t GetV1PlusOne() const { return v1 + 1; }
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredBaseObject",
TestUnregisteredBaseObject,
+ Object);
+};
- TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredObject",
TestUnregisteredObject, Object);
+class TestUnregisteredObject : public TestUnregisteredBaseObject {
+ public:
+ int64_t v2;
+ explicit TestUnregisteredObject(int64_t v1, int64_t v2)
+ : TestUnregisteredBaseObject(v1), v2(v2) {}
+ int64_t GetV2PlusTwo() const { return v2 + 2; }
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredObject",
TestUnregisteredObject,
+ TestUnregisteredBaseObject);
};
TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) {
@@ -176,6 +186,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_static("__ffi_init__", refl::init<TestCxxClassDerived, int64_t,
int32_t, double, float>)
.def_rw("v_f64", &TestCxxClassDerived::v_f64)
.def_rw("v_f32", &TestCxxClassDerived::v_f32);
+
refl::ObjectDef<TestCxxClassDerivedDerived>()
.def_static(
"__ffi_init__",
@@ -189,6 +200,21 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
.def_rw("note", &TestCxxInitSubsetObj::note);
+ refl::ObjectDef<TestUnregisteredBaseObject>()
+ .def_ro("v1", &TestUnregisteredBaseObject::v1)
+ .def_static("__ffi_init__", refl::init<TestUnregisteredBaseObject,
int64_t>,
+ "Constructor of TestUnregisteredBaseObject")
+ .def("get_v1_plus_one", &TestUnregisteredBaseObject::GetV1PlusOne,
+ "Get (v1 + 1) from TestUnregisteredBaseObject");
+
+ refl::ObjectDef<TestUnregisteredObject>()
+ .def_ro("v1", &TestUnregisteredObject::v1)
+ .def_ro("v2", &TestUnregisteredObject::v2)
+ .def_static("__ffi_init__", refl::init<TestUnregisteredObject, int64_t,
int64_t>,
+ "Constructor of TestUnregisteredObject")
+ .def("get_v2_plus_two", &TestUnregisteredObject::GetV2PlusTwo,
+ "Get (v2 + 2) from TestUnregisteredObject");
+
refl::GlobalDef()
.def("testing.test_raise_error", TestRaiseError)
.def_packed("testing.nop", [](PackedArgs args, Any* ret) {})
@@ -206,7 +232,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
})
.def("testing.object_use_count", [](const Object* obj) { return
obj->use_count(); })
.def("testing.make_unregistered_object",
- []() { return ObjectRef(make_object<TestUnregisteredObject>(42));
});
+ []() { return ObjectRef(make_object<TestUnregisteredObject>(41,
42)); });
}
} // namespace ffi
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 292c8e9..d9c6698 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -54,7 +54,7 @@ class TypeTable {
/*! \brief stored type key */
String type_key_data;
/*! \brief acenstor information */
- std::vector<const TVMFFITypeInfo*> type_acenstors_data;
+ std::vector<const TVMFFITypeInfo*> type_ancestors_data;
/*! \brief type fields informaton */
std::vector<TVMFFIFieldInfo> type_fields_data;
/*! \brief type methods informaton */
@@ -81,21 +81,21 @@ class TypeTable {
if (type_depth != 0) {
TVM_FFI_ICHECK_NOTNULL(parent);
TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1);
- type_acenstors_data.resize(type_depth);
+ type_ancestors_data.resize(type_depth);
// copy over parent's type information
for (int32_t i = 0; i < parent->type_depth; ++i) {
- type_acenstors_data[i] = parent->type_acenstors[i];
+ type_ancestors_data[i] = parent->type_ancestors[i];
}
// set last type information to be parent
- type_acenstors_data[parent->type_depth] = parent;
+ type_ancestors_data[parent->type_depth] = parent;
}
- // initialize type info: no change to type_key and type_acenstors fields
+ // initialize type info: no change to type_key and type_ancestors fields
// after this line
this->type_index = type_index;
this->type_depth = type_depth;
this->type_key = TVMFFIByteArray{this->type_key_data.data(),
this->type_key_data.length()};
this->type_key_hash = std::hash<String>()(this->type_key_data);
- this->type_acenstors = type_acenstors_data.data();
+ this->type_ancestors = type_ancestors_data.data();
// initialize the reflection information
this->num_fields = 0;
this->num_methods = 0;
@@ -280,7 +280,7 @@ class TypeTable {
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
const Entry* ptr = it->get();
if (ptr != nullptr && ptr->type_depth != 0) {
- int parent_index = ptr->type_acenstors[ptr->type_depth -
1]->type_index;
+ int parent_index = ptr->type_ancestors[ptr->type_depth -
1]->type_index;
num_children[parent_index] += num_children[ptr->type_index] + 1;
if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) {
expected_child_slots[ptr->type_index] = ptr->num_slots - 1;
@@ -293,7 +293,7 @@ class TypeTable {
if (ptr != nullptr && num_children[ptr->type_index] >=
min_children_count) {
std::cerr << '[' << ptr->type_index << "]\t" <<
ToStringView(ptr->type_key);
if (ptr->type_depth != 0) {
- int32_t parent_index = ptr->type_acenstors[ptr->type_depth -
1]->type_index;
+ int32_t parent_index = ptr->type_ancestors[ptr->type_depth -
1]->type_index;
std::cerr << "\tparent=" <<
ToStringView(type_table_[parent_index]->type_key);
} else {
std::cerr << "\tparent=root";
diff --git a/tests/cpp/test_object.cc b/tests/cpp/test_object.cc
index ec5c54c..6c8c822 100644
--- a/tests/cpp/test_object.cc
+++ b/tests/cpp/test_object.cc
@@ -55,8 +55,8 @@ TEST(Object, TypeInfo) {
EXPECT_TRUE(info != nullptr);
EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex());
EXPECT_EQ(info->type_depth, 2);
- EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index);
- EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index);
+ EXPECT_EQ(info->type_ancestors[0]->type_index, Object::_type_index);
+ EXPECT_EQ(info->type_ancestors[1]->type_index, TNumberObj::_type_index);
EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
}
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 3b36f5b..0c64e46 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -19,6 +19,7 @@ from typing import Any
import pytest
import tvm_ffi
+from tvm_ffi.core import TypeInfo
def test_make_object() -> None:
@@ -103,12 +104,22 @@ def test_opaque_object() -> None:
def test_unregistered_object_fallback() -> None:
- with pytest.warns(
- UserWarning,
- match=(
- r"Returning type `testing\.TestUnregisteredObject` "
- r"which is not registered via register_object, fallback to Object"
- ),
- ):
+ def _check_type(x: Any) -> None:
+ type_info: TypeInfo = type(x).__tvm_ffi_type_info__ # type:
ignore[attr-defined]
+ assert type_info.type_key == "testing.TestUnregisteredObject"
+ assert x.v1 == 41
+ assert x.v2 == 42
+ assert x.get_v1_plus_one() == 42 # type: ignore[attr-defined]
+ assert x.get_v2_plus_two() == 44 # type: ignore[attr-defined]
+ assert type(x).__name__ == "TestUnregisteredObject"
+ assert type(x).__module__ == "testing"
+ assert type(x).__qualname__ == "testing.TestUnregisteredObject"
+ assert "Auto-generated fallback class" in type(x).__doc__ # type:
ignore[operator]
+ assert "Get (v1 + 1) from TestUnregisteredBaseObject" in
type(x).get_v1_plus_one.__doc__ # type: ignore[attr-defined]
+ assert "Get (v2 + 2) from TestUnregisteredObject" in
type(x).get_v2_plus_two.__doc__ # type: ignore[attr-defined]
+
+ obj = tvm_ffi.testing.make_unregistered_object()
+ _check_type(obj)
+ for _ in range(5):
obj = tvm_ffi.testing.make_unregistered_object()
- assert type(obj) is tvm_ffi.Object
+ _check_type(obj)