This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 0729193  feat: Automatically add `__init__` method if available (#174)
0729193 is described below

commit 0729193f475c7ab1059524fcfa6ffc742b0addac
Author: Junru Shao <[email protected]>
AuthorDate: Sun Oct 19 12:28:15 2025 -0700

    feat: Automatically add `__init__` method if available (#174)
    
    Previously, when `SomeObject.__init__` is not defined,
    `SomeObject(anything, ...)` silently does nothing and returns a
    `SomeObject` with `chandle=None`, which later causes a segfault when
    its fields are accessed.
    
    This PR fixes this issue by auto-generating `__init__` when it's not
    available:
    - if `__ffi_init__` method is defined, generate a method that calls
    `__ffi_init__`;
    - if not, generate an `__init__` method that explicitly raises an
    error stating that `__init__` is not implemented.
---
 python/tvm_ffi/cython/object.pxi |  2 +-
 python/tvm_ffi/registry.py       | 11 +++++++++++
 python/tvm_ffi/testing.py        |  4 ----
 tests/python/test_object.py      |  4 ++--
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index a8a73b0..499a032 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -278,7 +278,7 @@ cdef inline object make_fallback_cls_for_type_index(int32_t type_index):
     # Create `type_info.type_cls` now
     class cls(parent_type_info.type_cls):
         pass
-    attrs = dict(cls.__dict__)
+    attrs = cls.__dict__.copy()
     attrs.pop("__dict__", None)
     attrs.pop("__weakref__", None)
     attrs.update({
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 3080cce..bd37ebd 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -273,15 +273,26 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
         name = field.name
         if not hasattr(type_cls, name):  # skip already defined attributes
             setattr(type_cls, name, field.as_property(type_cls))
+    has_c_init = False
     for method in type_info.methods:
         name = method.name
         if name == "__ffi_init__":
             name = "__c_ffi_init__"
+            has_c_init = True
         if not hasattr(type_cls, name):
             setattr(type_cls, name, method.as_callable(type_cls))
+    if "__init__" not in type_cls.__dict__:
+        if has_c_init:
+            setattr(type_cls, "__init__", getattr(type_cls, "__ffi_init__"))
+        elif not issubclass(type_cls, core.PyNativeObject):
+            setattr(type_cls, "__init__", __init__invalid)
     return type_cls
 
 
+def __init__invalid(self: Any, *args: Any, **kwargs: Any) -> None:
+    raise RuntimeError("The __init__ method of this class is not implemented.")
+
+
 __all__ = [
     "get_global_func",
     "get_global_func_metadata",
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 820fd60..ce9205f 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -61,10 +61,6 @@ class TestIntPair(Object):
         # fmt: on
     # tvm-ffi-stubgen(end)
 
-    def __init__(self, a: int, b: int) -> None:
-        """Construct the object."""
-        self.__ffi_init__(a, b)
-
 
 @register_object("testing.TestObjectDerived")
 class TestObjectDerived(TestObjectBase):
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index ebe8043..4701a74 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -32,7 +32,7 @@ def test_make_object() -> None:
 
 
 def test_make_object_via_init() -> None:
-    obj0 = tvm_ffi.testing.TestIntPair(1, 2)
+    obj0 = tvm_ffi.testing.TestIntPair(1, 2)  # type: ignore[call-arg]
     assert obj0.a == 1
     assert obj0.b == 2
 
@@ -46,7 +46,7 @@ def test_method() -> None:
 
 
 def test_attribute() -> None:
-    obj = tvm_ffi.testing.TestIntPair(3, 4)
+    obj = tvm_ffi.testing.TestIntPair(3, 4)  # type: ignore[call-arg]
     assert obj.a == 3
     assert obj.b == 4
     assert type(obj).a.__doc__ == "Field `a`"

Reply via email to