This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8dcaec1  feat: Add `keep_module_alive` to `load_module` (#334)
8dcaec1 is described below

commit 8dcaec1fb47bf7873b105385b7d2808d51f6b342
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 11 21:24:41 2025 -0800

    feat: Add `keep_module_alive` to `load_module` (#334)
    
    Should fix #264, #322. Supersedes #331.
    
    This PR introduces a flag `keep_module_alive: bool = True` to the
    following APIs:
    - **`load_module`**
    - **`cpp.load_inline`**
    - **`cpp.load`**
    
    By default, the flag is `True`, which registers the loaded module in
    `ModuleGlobals` inside `libtvm_ffi.so`. The registry holds a reference
    to each registered module, so the Python side will not unload it.
    
    Two private APIs are added:
    - `tvm_ffi._ffi_api.ModuleGlobalsAdd`: adds a module to the global
    registry `ModuleGlobals`
    - `tvm_ffi._ffi_api.ModuleGlobalsRemove`: removes a module from the
    global registry `ModuleGlobals`
---
 python/tvm_ffi/_ffi_api.py      |  4 +++
 python/tvm_ffi/cpp/extension.py | 60 ++++++++++++++++++++++++++---------------
 python/tvm_ffi/module.py        | 11 ++++++--
 src/ffi/extra/module.cc         | 33 ++++++++++++++++++++++-
 tests/python/test_build.py      |  9 -------
 tests/python/test_stl.py        | 43 +++++++++++------------------
 6 files changed, 99 insertions(+), 61 deletions(-)

diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index ab6704a..4b49716 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -58,6 +58,8 @@ if TYPE_CHECKING:
     def ModuleGetKind(_0: Module, /) -> str: ...
     def ModuleGetPropertyMask(_0: Module, /) -> int: ...
     def ModuleGetWriteFormats(_0: Module, /) -> Sequence[str]: ...
+    def ModuleGlobalsAdd(_0: Module, /) -> None: ...
+    def ModuleGlobalsRemove(_0: Module, /) -> None: ...
     def ModuleImplementsFunction(_0: Module, _1: str, _2: bool, /) -> bool: ...
     def ModuleImportModule(_0: Module, _1: Module, /) -> None: ...
     def ModuleInspectSource(_0: Module, _1: str, /) -> str: ...
@@ -101,6 +103,8 @@ __all__ = [
     "ModuleGetKind",
     "ModuleGetPropertyMask",
     "ModuleGetWriteFormats",
+    "ModuleGlobalsAdd",
+    "ModuleGlobalsRemove",
     "ModuleImplementsFunction",
     "ModuleImportModule",
     "ModuleInspectSource",
diff --git a/python/tvm_ffi/cpp/extension.py b/python/tvm_ffi/cpp/extension.py
index b03ce67..761f0dd 100644
--- a/python/tvm_ffi/cpp/extension.py
+++ b/python/tvm_ffi/cpp/extension.py
@@ -751,7 +751,7 @@ def build_inline(
         )
 
 
-def load_inline(
+def load_inline(  # noqa: PLR0913
     name: str,
     *,
     cpp_sources: Sequence[str] | str | None = None,
@@ -763,6 +763,7 @@ def load_inline(
     extra_include_paths: Sequence[str] | None = None,
     build_directory: str | None = None,
     embed_cubin: Mapping[str, bytes] | None = None,
+    keep_module_alive: bool = True,
 ) -> Module:
     """Compile, build and load a C++/CUDA module from inline source code.
 
@@ -791,13 +792,13 @@ def load_inline(
 
     Parameters
     ----------
-    name: str
+    name
         The name of the tvm ffi module.
-    cpp_sources: Sequence[str] | str, optional
+    cpp_sources
         The C++ source code. It can be a list of sources or a single source.
-    cuda_sources: Sequence[str] | str, optional
+    cuda_sources
         The CUDA source code. It can be a list of sources or a single source.
-    functions: Mapping[str, str] | Sequence[str] | str, optional
+    functions
         The functions in cpp_sources or cuda_source that will be exported to 
the tvm ffi module. When a mapping is
         given, the keys are the names of the exported functions, and the 
values are docstrings for the functions
         (use an empty string to skip documentation for specific functions). 
When a sequence or a single string is given, they are
@@ -805,41 +806,48 @@ def load_inline(
         also be given as a string. When cpp_sources is given, the functions 
must be declared (not necessarily defined)
         in the cpp_sources. When cpp_sources is not given, the functions must 
be defined in the cuda_sources. If not
         specified, no function will be exported.
-    extra_cflags: Sequence[str], optional
+    extra_cflags
         The extra compiler flags for C++ compilation.
         The default flags are:
 
         - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
         - On Windows: ['/std:c++17', '/O2']
 
-    extra_cuda_cflags: Sequence[str], optional
+    extra_cuda_cflags
         The extra compiler flags for CUDA compilation.
 
-    extra_ldflags: Sequence[str], optional
+    extra_ldflags
         The extra linker flags.
         The default flags are:
 
         - On Linux/macOS: ['-shared']
         - On Windows: ['/DLL']
 
-    extra_include_paths: Sequence[str], optional
+    extra_include_paths
         The extra include paths.
 
-    build_directory: str, optional
+    build_directory
         The build directory. If not specified, a default tvm ffi cache 
directory will be used. By default, the
         cache directory is ``~/.cache/tvm-ffi``. You can also set the 
``TVM_FFI_CACHE_DIR`` environment variable to
         specify the cache directory.
 
-    embed_cubin: Mapping[str, bytes], optional
+    embed_cubin
         A mapping from CUBIN module names to CUBIN binary data. When provided, 
the CUBIN data will be embedded
         into the compiled shared library using objcopy, making it accessible 
via the TVM_FFI_EMBED_CUBIN macro.
         The keys should match the names used in TVM_FFI_EMBED_CUBIN calls in 
the C++ source code.
 
+    keep_module_alive
+        Whether to keep the module alive. If True, the module will be kept 
alive
+        for the duration of the program until libtvm_ffi.so is unloaded.
+
     Returns
     -------
     mod: Module
         The loaded tvm ffi module.
 
+    See Also
+    --------
+    :py:func:`tvm_ffi.load_module`
 
     Example
     -------
@@ -892,7 +900,8 @@ def load_inline(
             extra_include_paths=extra_include_paths,
             build_directory=build_directory,
             embed_cubin=embed_cubin,
-        )
+        ),
+        keep_module_alive=keep_module_alive,
     )
 
 
@@ -1045,6 +1054,7 @@ def load(
     extra_ldflags: Sequence[str] | None = None,
     extra_include_paths: Sequence[str] | None = None,
     build_directory: str | None = None,
+    keep_module_alive: bool = True,
 ) -> Module:
     """Compile, build and load a C++/CUDA module from source files.
 
@@ -1071,48 +1081,55 @@ def load(
 
     Parameters
     ----------
-    name: str
+    name
         The name of the tvm ffi module.
-    cpp_files: Sequence[str] | str, optional
+    cpp_files
         The C++ source files to compile. It can be a list of file paths or a 
single file path. Both absolute and
         relative paths are supported.
-    cuda_files: Sequence[str] | str, optional
+    cuda_files
         The CUDA source files to compile. It can be a list of file paths or a 
single file path. Both absolute and
         relative paths are supported.
-    extra_cflags: Sequence[str], optional
+    extra_cflags
         The extra compiler flags for C++ compilation.
         The default flags are:
 
         - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
         - On Windows: ['/std:c++17', '/MD', '/O2']
 
-    extra_cuda_cflags: Sequence[str], optional
+    extra_cuda_cflags
         The extra compiler flags for CUDA compilation.
         The default flags are:
 
         - ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] (Linux/macOS)
         - ['-Xcompiler', '/std:c++17', '/O2'] (Windows)
 
-    extra_ldflags: Sequence[str], optional
+    extra_ldflags
         The extra linker flags.
         The default flags are:
 
         - On Linux/macOS: ['-shared', '-L<tvm_ffi_lib_path>', '-ltvm_ffi']
         - On Windows: ['/DLL', '/LIBPATH:<tvm_ffi_lib_path>', 
'<tvm_ffi_lib_name>.lib']
 
-    extra_include_paths: Sequence[str], optional
+    extra_include_paths
         The extra include paths for header files. Both absolute and relative 
paths are supported.
 
-    build_directory: str, optional
+    build_directory
         The build directory. If not specified, a default tvm ffi cache 
directory will be used. By default, the
         cache directory is ``~/.cache/tvm-ffi``. You can also set the 
``TVM_FFI_CACHE_DIR`` environment variable to
         specify the cache directory.
 
+    keep_module_alive
+        Whether to keep the module alive. If True, the module will be kept 
alive
+        for the duration of the program until libtvm_ffi.so is unloaded.
+
     Returns
     -------
     mod: Module
         The loaded tvm ffi module.
 
+    See Also
+    --------
+    :py:func:`tvm_ffi.load_module`
 
     Example
     -------
@@ -1169,5 +1186,6 @@ def load(
             extra_ldflags=extra_ldflags,
             extra_include_paths=extra_include_paths,
             build_directory=build_directory,
-        )
+        ),
+        keep_module_alive=keep_module_alive,
     )
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 7b548d6..0c04d21 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -427,7 +427,7 @@ def system_lib(symbol_prefix: str = "") -> Module:
     return _ffi_api.SystemLib(symbol_prefix)
 
 
-def load_module(path: str | PathLike) -> Module:
+def load_module(path: str | PathLike, keep_module_alive: bool = True) -> 
Module:
     """Load module from file.
 
     Parameters
@@ -435,6 +435,10 @@ def load_module(path: str | PathLike) -> Module:
     path
         The path to the module file.
 
+    keep_module_alive
+        Whether to keep the module alive. If True, the module will be kept 
alive
+        for the duration of the program until libtvm_ffi.so is unloaded.
+
     Returns
     -------
     The loaded module
@@ -459,4 +463,7 @@ def load_module(path: str | PathLike) -> Module:
 
     """
     path = fspath(path)
-    return _ffi_api.ModuleLoadFromFile(path)
+    mod = _ffi_api.ModuleLoadFromFile(path)
+    if keep_module_alive:
+        _ffi_api.ModuleGlobalsAdd(mod)
+    return mod
diff --git a/src/ffi/extra/module.cc b/src/ffi/extra/module.cc
index 0d3fd43..31bb95b 100644
--- a/src/ffi/extra/module.cc
+++ b/src/ffi/extra/module.cc
@@ -22,6 +22,7 @@
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 
+#include <mutex>
 #include <unordered_set>
 #include <vector>
 
@@ -30,6 +31,33 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief Global modules, i.e. modules that are owned by the runtime and 
should not be unloaded.
+ * On the frontend, a module is added to the registry if `keep_module_alive=True` 
when `load_module` is
+ * called.
+ */
+class ModuleGlobals {
+ public:
+  void Add(const Module& m) {
+    std::scoped_lock<std::mutex> lock(mutex_);
+    modules_.Set(m, 1);
+  }
+
+  void Remove(const Module& m) {
+    std::scoped_lock<std::mutex> lock(mutex_);
+    modules_.erase(m);
+  }
+
+  static ModuleGlobals* Get() {
+    static ModuleGlobals instance;
+    return &instance;
+  }
+
+ private:
+  Map<Module, int> modules_;
+  std::mutex mutex_;
+};
+
 Optional<Function> ModuleObj::GetFunction(const String& name, bool 
query_imports) {
   if (auto opt_func = this->GetFunction(name)) {
     return opt_func;
@@ -161,7 +189,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_method("ffi.ModuleGetWriteFormats", &ModuleObj::GetWriteFormats)
       .def_method("ffi.ModuleWriteToFile", &ModuleObj::WriteToFile)
       .def_method("ffi.ModuleImportModule", &ModuleObj::ImportModule)
-      .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports);
+      .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports)
+      .def_method("ffi.ModuleGlobalsAdd", [](const Module& mod) { 
ModuleGlobals::Get()->Add(mod); })
+      .def_method("ffi.ModuleGlobalsRemove",
+                  [](const Module& mod) { ModuleGlobals::Get()->Remove(mod); 
});
 }
 }  // namespace ffi
 }  // namespace tvm
diff --git a/tests/python/test_build.py b/tests/python/test_build.py
index 7f22669..8844616 100644
--- a/tests/python/test_build.py
+++ b/tests/python/test_build.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import gc
 import pathlib
 
 import numpy
@@ -196,10 +195,6 @@ def test_build_inline_with_metadata() -> None:  # noqa: 
PLR0915
     schema = TypeSchema.from_json_str(metadata["type_schema"])
     assert str(schema) == "Callable[[Tensor, Tensor], None]"
 
-    # Explicitly cleanup all objects before module unload to avoid 
use-after-free
-    del metadata, schema, result, x, y, mod
-    gc.collect()
-
 
 def test_build_inline_with_docstrings() -> None:
     """Test building functions with documentation using the functions dict."""
@@ -279,10 +274,6 @@ def test_build_inline_with_docstrings() -> None:
     assert doc is not None, "divide should have documentation"
     assert doc == divide_docstring
 
-    # Explicitly cleanup all objects before module unload to avoid 
use-after-free
-    del metadata, schema, doc, result, mod
-    gc.collect()
-
 
 def test_build_without_metadata() -> None:
     """Test building without metadata export."""
diff --git a/tests/python/test_stl.py b/tests/python/test_stl.py
index 43037a6..3ba70b8 100644
--- a/tests/python/test_stl.py
+++ b/tests/python/test_stl.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import gc
 import pathlib
 
 import pytest
@@ -31,33 +30,21 @@ def test_stl() -> None:
 
     mod: Module = tvm_ffi.load_module(output_lib_path)
 
-    def run_check(mod: Module) -> None:
-        # This sub function is needed to make sure all temp variables 
deallocated
-        # before module unload since some of these objects contains deleters 
in the library
-        # code. If the module is unloaded before the object is deleted, the 
deleter
-        # may call an invalid address.
-        assert list(mod.test_tuple([1, 2.5])) == [2.5, 1]
-        assert mod.test_vector(None) is None
-        assert list(mod.test_vector([[1, 2], [3, 4]])) == [3, 7]
-        assert mod.test_variant(1) == "int"
-        assert mod.test_variant(1.0) == "float"
-        assert list(mod.test_variant([1, 1.0])) == ["int", "float"]
-        assert dict(mod.test_map({"a": 1, "b": 2})) == {1: "a", 2: "b"}
-        assert dict(mod.test_map_2({"a": 1, "b": 2})) == {1: "a", 2: "b"}
-        assert mod.test_function(lambda: 0)() == 1
-        assert mod.test_function(lambda: 10)() == 11
-
-        with pytest.raises(TypeError):
-            mod.test_tuple([1.5, 2.5])
-        with pytest.raises(TypeError):
-            mod.test_function(lambda: 0)(100)
-
-    run_check(mod)
-    # Force garbage collection to ensure that all objects created from the
-    # module are destroyed before the module itself is unloaded. This
-    # prevents segfaults from calling destructors in an unloaded library.
-    gc.collect()
-    del mod
+    assert list(mod.test_tuple([1, 2.5])) == [2.5, 1]
+    assert mod.test_vector(None) is None
+    assert list(mod.test_vector([[1, 2], [3, 4]])) == [3, 7]
+    assert mod.test_variant(1) == "int"
+    assert mod.test_variant(1.0) == "float"
+    assert list(mod.test_variant([1, 1.0])) == ["int", "float"]
+    assert dict(mod.test_map({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+    assert dict(mod.test_map_2({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+    assert mod.test_function(lambda: 0)() == 1
+    assert mod.test_function(lambda: 10)() == 11
+
+    with pytest.raises(TypeError):
+        mod.test_tuple([1.5, 2.5])
+    with pytest.raises(TypeError):
+        mod.test_function(lambda: 0)(100)
 
 
 if __name__ == "__main__":

Reply via email to