This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 8786670 feat: Add `tvm_ffi.Function.__init__` (#395)
8786670 is described below
commit 8786670af20e9005ec1dd50d7872ee9e2b8b89cd
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jan 10 09:29:52 2026 -0800
feat: Add `tvm_ffi.Function.__init__` (#395)
This change enables us to convert Python functions directly to
`tvm_ffi.Function` using its constructor. Example:
```python
def add(a: int, b: int) -> int:
    return a + b
func = tvm_ffi.Function(add)
assert isinstance(func, tvm_ffi.Function)
assert func(1, 2) == 3
```
---
include/tvm/ffi/base_details.h | 16 ++++++++++++----
python/tvm_ffi/core.pyi | 1 +
python/tvm_ffi/cython/function.pxi | 21 +++++++++++++++++++++
tests/python/test_function.py | 9 +++++++++
4 files changed, 43 insertions(+), 4 deletions(-)
diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index 7224ac1..a00d4a4 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -86,6 +86,7 @@
#else
#define TVM_FFI_FUNC_SIG __func__
#endif
+/// \endcond
#if defined(__GNUC__)
// gcc and clang and attribute constructor
@@ -114,14 +115,22 @@
return 0; \
}(); \
static void FnName()
-
+/// \endcond
+/*!
+ * \brief Macro that defines a block that will be called during static initialization.
+ *
+ * \code{.cpp}
+ * TVM_FFI_STATIC_INIT_BLOCK() {
+ * RegisterFunctions();
+ * }
+ * \endcode
+ */
#define TVM_FFI_STATIC_INIT_BLOCK() \
  TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \
                                 TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__))
-/// \endcond
#endif
-/*
+/*!
* \brief Define the default copy/move constructor and assign operator
* \param TypeName The class typename.
*/
@@ -313,5 +322,4 @@ using TypeSchema = TypeSchemaImpl<std::remove_const_t<std::remove_reference_t<T>
} // namespace details
} // namespace ffi
} // namespace tvm
-/// \endcond
#endif // TVM_FFI_BASE_DETAILS_H_
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 2cee79c..1c3a3da 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -184,6 +184,7 @@ class DLTensorTestWrapper:
def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...
 class Function(Object):
+    def __init__(self, func: Callable[..., Any]) -> None: ...
     @property
     def release_gil(self) -> bool: ...
     @release_gil.setter
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 01a3366..67429ea 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -889,6 +889,27 @@ cdef class Function(Object):
     def __cinit__(self) -> None:
         self.c_release_gil = _RELEASE_GIL_BY_DEFAULT
+    def __init__(self, func: Callable[..., Any]) -> None:
+        """Initialize a Function from a Python callable.
+
+        This constructor allows creating a `tvm_ffi.Function` directly
+        from a Python function or another `tvm_ffi.Function` instance.
+
+        Parameters
+        ----------
+        func : Callable[..., Any]
+            The Python callable to wrap.
+        """
+        cdef TVMFFIObjectHandle chandle = NULL
+        if not callable(func):
+            raise TypeError(f"func must be callable, got {type(func)}")
+        if isinstance(func, Function):
+            chandle = (<Object>func).chandle
+            TVMFFIObjectIncRef(chandle)
+        else:
+            _convert_to_ffi_func_handle(func, &chandle)
+        self.chandle = chandle
+
     property release_gil:
         """Whether calls release the Python GIL while executing."""
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 8a494fb..34d3bdc 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -153,6 +153,15 @@ def test_pyfunc_convert() -> None:
     assert fapply(add, 1, 3.3) == 4.3
+def test_pyfunc_init() -> None:
+    def add(a: int, b: int) -> int:
+        return a + b
+
+    fadd = tvm_ffi.Function(add)
+    assert isinstance(fadd, tvm_ffi.Function)
+    assert fadd(1, 2) == 3
+
+
 def test_global_func() -> None:
     @tvm_ffi.register_global_func("mytest.echo")
     def echo(x: Any) -> Any: