This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 0729193 feat: Automatically add `__init__` method if available (#174)
0729193 is described below
commit 0729193f475c7ab1059524fcfa6ffc742b0addac
Author: Junru Shao <[email protected]>
AuthorDate: Sun Oct 19 12:28:15 2025 -0700
feat: Automatically add `__init__` method if available (#174)
Previously, when `SomeObject.__init__` was not defined,
`SomeObject(anything, ...)` silently did nothing and returned a
`SomeObject` with `chandle=None`, which later caused a segfault when
its fields were accessed.
This PR fixes the issue by auto-generating `__init__` when the class does
not define one:
- if an `__ffi_init__` method is defined, generate an `__init__` that
calls `__ffi_init__`;
- otherwise (except for subclasses of `core.PyNativeObject`), generate an
`__init__` method that raises an error stating it is not implemented.
---
python/tvm_ffi/cython/object.pxi | 2 +-
python/tvm_ffi/registry.py | 11 +++++++++++
python/tvm_ffi/testing.py | 4 ----
tests/python/test_object.py | 4 ++--
4 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index a8a73b0..499a032 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -278,7 +278,7 @@ cdef inline object make_fallback_cls_for_type_index(int32_t type_index):
# Create `type_info.type_cls` now
class cls(parent_type_info.type_cls):
pass
- attrs = dict(cls.__dict__)
+ attrs = cls.__dict__.copy()
attrs.pop("__dict__", None)
attrs.pop("__weakref__", None)
attrs.update({
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 3080cce..bd37ebd 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -273,15 +273,26 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
name = field.name
if not hasattr(type_cls, name): # skip already defined attributes
setattr(type_cls, name, field.as_property(type_cls))
+ has_c_init = False
for method in type_info.methods:
name = method.name
if name == "__ffi_init__":
name = "__c_ffi_init__"
+ has_c_init = True
if not hasattr(type_cls, name):
setattr(type_cls, name, method.as_callable(type_cls))
+ if "__init__" not in type_cls.__dict__:
+ if has_c_init:
+ setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
+ elif not issubclass(type_cls, core.PyNativeObject):
+ setattr(type_cls, "__init__", __init__invalid)
return type_cls
+def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None:
+ raise RuntimeError("The __init__ method of this class is not implemented.")
+
+
__all__ = [
"get_global_func",
"get_global_func_metadata",
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 820fd60..ce9205f 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -61,10 +61,6 @@ class TestIntPair(Object):
# fmt: on
# tvm-ffi-stubgen(end)
- def __init__(self, a: int, b: int) -> None:
- """Construct the object."""
- self.__ffi_init__(a, b)
-
@register_object("testing.TestObjectDerived")
class TestObjectDerived(TestObjectBase):
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index ebe8043..4701a74 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -32,7 +32,7 @@ def test_make_object() -> None:
def test_make_object_via_init() -> None:
- obj0 = tvm_ffi.testing.TestIntPair(1, 2)
+ obj0 = tvm_ffi.testing.TestIntPair(1, 2) # type: ignore[call-arg]
assert obj0.a == 1
assert obj0.b == 2
@@ -46,7 +46,7 @@ def test_method() -> None:
def test_attribute() -> None:
- obj = tvm_ffi.testing.TestIntPair(3, 4)
+ obj = tvm_ffi.testing.TestIntPair(3, 4) # type: ignore[call-arg]
assert obj.a == 3
assert obj.b == 4
assert type(obj).a.__doc__ == "Field `a`"