This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 59c91c1 [TEST] Add testcase about subclassing an ffi.Function (#140)
59c91c1 is described below
commit 59c91c17eb7ef4f24cf00faedc82f1a8e0fc53a3
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Oct 15 21:57:23 2025 -0400
[TEST] Add testcase about subclassing an ffi.Function (#140)
Sometimes it is helpful to subclass `ffi.Function` and attach extra metadata. This PR
adds a test case covering that pattern.
---
tests/python/test_function.py | 34 ++++++++++++++++++++++++++++++++++
1 file changed, 34 insertions(+)
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index cf50dcf..cc932f5 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -263,3 +263,37 @@ def test_function_from_mlir_packed_safe_call() -> None:
assert sys.getrefcount(keep_alive) == base_ref_count + 1
fadd_one = None
assert sys.getrefcount(keep_alive) == base_ref_count
+
+
+def test_function_subclass() -> None:
+    class JitFunction:
+        def __init__(self, metadata: Any) -> None:
+            self.metadata = metadata
+
+    class MyFunction(tvm_ffi.Function, JitFunction):
+        def __init__(self, metadata: Any) -> None:
+            # Explicitly initialize the mixin. `super()` is not used because `tvm_ffi.Function`
+            # is an extension type without a standard `__init__`.
+            JitFunction.__init__(self, metadata)
+
+        # When subclassing a Cython cdef class and overriding `__init__`,
+        # special methods like `__call__` may not be inherited automatically.
+        # This explicit assignment ensures the subclass remains callable.
+        __call__ = tvm_ffi.Function.__call__  # type: ignore
+
+    f = tvm_ffi.convert(lambda x: x)
+    assert isinstance(f, tvm_ffi.Function)
+    f_sub = MyFunction(128)
+    # Move the handle of the existing function `f` into `f_sub`.
+    f_sub.__move_handle_from__(f)
+    assert isinstance(f_sub, MyFunction)
+    assert isinstance(f_sub, JitFunction)
+    assert f_sub.metadata == 128
+
+    y: int = f_sub(2)
+    assert y == 2
+    echo = tvm_ffi.get_global_func("testing.echo")
+    fechoed = echo(f_sub)
+    assert isinstance(fechoed, tvm_ffi.Function)
+    assert fechoed.__chandle__() == f_sub.__chandle__()
+    assert fechoed(10) == 10