This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 8dcaec1 feat: Add `keep_module_alive` to `load_module` (#334)
8dcaec1 is described below
commit 8dcaec1fb47bf7873b105385b7d2808d51f6b342
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 11 21:24:41 2025 -0800
feat: Add `keep_module_alive` to `load_module` (#334)
Should fix #264, #322. Supersedes #331.
This PR introduces a flag `keep_module_alive: bool = True` to the
following APIs:
- **`load_module`**
- **`cpp.load_inline`**
- **`cpp.load`**
By default, the flag is set to `True`, meaning the loaded module is added to
`ModuleGlobals` inside `libtvm_ffi.so`. That registry retains a reference to
each such module, so the Python side will not unload the module.
Two private APIs are added:
- `tvm_ffi._ffi_api.ModuleGlobalsAdd`: adds a module to global registry
`ModuleGlobals`
- `tvm_ffi._ffi_api.ModuleGlobalsRemove`: removes the module from global
registry `ModuleGlobals`
---
python/tvm_ffi/_ffi_api.py | 4 +++
python/tvm_ffi/cpp/extension.py | 60 ++++++++++++++++++++++++++---------------
python/tvm_ffi/module.py | 11 ++++++--
src/ffi/extra/module.cc | 33 ++++++++++++++++++++++-
tests/python/test_build.py | 9 -------
tests/python/test_stl.py | 43 +++++++++++------------------
6 files changed, 99 insertions(+), 61 deletions(-)
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index ab6704a..4b49716 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -58,6 +58,8 @@ if TYPE_CHECKING:
def ModuleGetKind(_0: Module, /) -> str: ...
def ModuleGetPropertyMask(_0: Module, /) -> int: ...
def ModuleGetWriteFormats(_0: Module, /) -> Sequence[str]: ...
+ def ModuleGlobalsAdd(_0: Module, /) -> None: ...
+ def ModuleGlobalsRemove(_0: Module, /) -> None: ...
def ModuleImplementsFunction(_0: Module, _1: str, _2: bool, /) -> bool: ...
def ModuleImportModule(_0: Module, _1: Module, /) -> None: ...
def ModuleInspectSource(_0: Module, _1: str, /) -> str: ...
@@ -101,6 +103,8 @@ __all__ = [
"ModuleGetKind",
"ModuleGetPropertyMask",
"ModuleGetWriteFormats",
+ "ModuleGlobalsAdd",
+ "ModuleGlobalsRemove",
"ModuleImplementsFunction",
"ModuleImportModule",
"ModuleInspectSource",
diff --git a/python/tvm_ffi/cpp/extension.py b/python/tvm_ffi/cpp/extension.py
index b03ce67..761f0dd 100644
--- a/python/tvm_ffi/cpp/extension.py
+++ b/python/tvm_ffi/cpp/extension.py
@@ -751,7 +751,7 @@ def build_inline(
)
-def load_inline(
+def load_inline( # noqa: PLR0913
name: str,
*,
cpp_sources: Sequence[str] | str | None = None,
@@ -763,6 +763,7 @@ def load_inline(
extra_include_paths: Sequence[str] | None = None,
build_directory: str | None = None,
embed_cubin: Mapping[str, bytes] | None = None,
+ keep_module_alive: bool = True,
) -> Module:
"""Compile, build and load a C++/CUDA module from inline source code.
@@ -791,13 +792,13 @@ def load_inline(
Parameters
----------
- name: str
+ name
The name of the tvm ffi module.
- cpp_sources: Sequence[str] | str, optional
+ cpp_sources
The C++ source code. It can be a list of sources or a single source.
- cuda_sources: Sequence[str] | str, optional
+ cuda_sources
The CUDA source code. It can be a list of sources or a single source.
- functions: Mapping[str, str] | Sequence[str] | str, optional
+ functions
The functions in cpp_sources or cuda_source that will be exported to
the tvm ffi module. When a mapping is
given, the keys are the names of the exported functions, and the
values are docstrings for the functions
(use an empty string to skip documentation for specific functions).
When a sequence or a single string is given, they are
@@ -805,41 +806,48 @@ def load_inline(
also be given as a string. When cpp_sources is given, the functions
must be declared (not necessarily defined)
in the cpp_sources. When cpp_sources is not given, the functions must
be defined in the cuda_sources. If not
specified, no function will be exported.
- extra_cflags: Sequence[str], optional
+ extra_cflags
The extra compiler flags for C++ compilation.
The default flags are:
- On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
- On Windows: ['/std:c++17', '/O2']
- extra_cuda_cflags: Sequence[str], optional
+ extra_cuda_cflags
The extra compiler flags for CUDA compilation.
- extra_ldflags: Sequence[str], optional
+ extra_ldflags
The extra linker flags.
The default flags are:
- On Linux/macOS: ['-shared']
- On Windows: ['/DLL']
- extra_include_paths: Sequence[str], optional
+ extra_include_paths
The extra include paths.
- build_directory: str, optional
+ build_directory
The build directory. If not specified, a default tvm ffi cache
directory will be used. By default, the
cache directory is ``~/.cache/tvm-ffi``. You can also set the
``TVM_FFI_CACHE_DIR`` environment variable to
specify the cache directory.
- embed_cubin: Mapping[str, bytes], optional
+ embed_cubin
A mapping from CUBIN module names to CUBIN binary data. When provided,
the CUBIN data will be embedded
into the compiled shared library using objcopy, making it accessible
via the TVM_FFI_EMBED_CUBIN macro.
The keys should match the names used in TVM_FFI_EMBED_CUBIN calls in
the C++ source code.
+ keep_module_alive
+ Whether to keep the module alive. If True, the module will be kept
alive
+ for the duration of the program until libtvm_ffi.so is unloaded.
+
Returns
-------
mod: Module
The loaded tvm ffi module.
+ See Also
+ --------
+ :py:func:`tvm_ffi.load_module`
Example
-------
@@ -892,7 +900,8 @@ def load_inline(
extra_include_paths=extra_include_paths,
build_directory=build_directory,
embed_cubin=embed_cubin,
- )
+ ),
+ keep_module_alive=keep_module_alive,
)
@@ -1045,6 +1054,7 @@ def load(
extra_ldflags: Sequence[str] | None = None,
extra_include_paths: Sequence[str] | None = None,
build_directory: str | None = None,
+ keep_module_alive: bool = True,
) -> Module:
"""Compile, build and load a C++/CUDA module from source files.
@@ -1071,48 +1081,55 @@ def load(
Parameters
----------
- name: str
+ name
The name of the tvm ffi module.
- cpp_files: Sequence[str] | str, optional
+ cpp_files
The C++ source files to compile. It can be a list of file paths or a
single file path. Both absolute and
relative paths are supported.
- cuda_files: Sequence[str] | str, optional
+ cuda_files
The CUDA source files to compile. It can be a list of file paths or a
single file path. Both absolute and
relative paths are supported.
- extra_cflags: Sequence[str], optional
+ extra_cflags
The extra compiler flags for C++ compilation.
The default flags are:
- On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
- On Windows: ['/std:c++17', '/MD', '/O2']
- extra_cuda_cflags: Sequence[str], optional
+ extra_cuda_cflags
The extra compiler flags for CUDA compilation.
The default flags are:
- ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] (Linux/macOS)
- ['-Xcompiler', '/std:c++17', '/O2'] (Windows)
- extra_ldflags: Sequence[str], optional
+ extra_ldflags
The extra linker flags.
The default flags are:
- On Linux/macOS: ['-shared', '-L<tvm_ffi_lib_path>', '-ltvm_ffi']
- On Windows: ['/DLL', '/LIBPATH:<tvm_ffi_lib_path>',
'<tvm_ffi_lib_name>.lib']
- extra_include_paths: Sequence[str], optional
+ extra_include_paths
The extra include paths for header files. Both absolute and relative
paths are supported.
- build_directory: str, optional
+ build_directory
The build directory. If not specified, a default tvm ffi cache
directory will be used. By default, the
cache directory is ``~/.cache/tvm-ffi``. You can also set the
``TVM_FFI_CACHE_DIR`` environment variable to
specify the cache directory.
+ keep_module_alive
+ Whether to keep the module alive. If True, the module will be kept
alive
+ for the duration of the program until libtvm_ffi.so is unloaded.
+
Returns
-------
mod: Module
The loaded tvm ffi module.
+ See Also
+ --------
+ :py:func:`tvm_ffi.load_module`
Example
-------
@@ -1169,5 +1186,6 @@ def load(
extra_ldflags=extra_ldflags,
extra_include_paths=extra_include_paths,
build_directory=build_directory,
- )
+ ),
+ keep_module_alive=keep_module_alive,
)
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 7b548d6..0c04d21 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -427,7 +427,7 @@ def system_lib(symbol_prefix: str = "") -> Module:
return _ffi_api.SystemLib(symbol_prefix)
-def load_module(path: str | PathLike) -> Module:
+def load_module(path: str | PathLike, keep_module_alive: bool = True) ->
Module:
"""Load module from file.
Parameters
@@ -435,6 +435,10 @@ def load_module(path: str | PathLike) -> Module:
path
The path to the module file.
+ keep_module_alive
+ Whether to keep the module alive. If True, the module will be kept
alive
+ for the duration of the program until libtvm_ffi.so is unloaded.
+
Returns
-------
The loaded module
@@ -459,4 +463,7 @@ def load_module(path: str | PathLike) -> Module:
"""
path = fspath(path)
- return _ffi_api.ModuleLoadFromFile(path)
+ mod = _ffi_api.ModuleLoadFromFile(path)
+ if keep_module_alive:
+ _ffi_api.ModuleGlobalsAdd(mod)
+ return mod
diff --git a/src/ffi/extra/module.cc b/src/ffi/extra/module.cc
index 0d3fd43..31bb95b 100644
--- a/src/ffi/extra/module.cc
+++ b/src/ffi/extra/module.cc
@@ -22,6 +22,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <mutex>
#include <unordered_set>
#include <vector>
@@ -30,6 +31,33 @@
namespace tvm {
namespace ffi {
+/*!
+ * \brief Global modules, i.e. modules that are owned by the runtime and
should not be unloaded.
+ * On the frontend, a module is added to the registry if `keep_module_alive=True`
when `load_module` is
+ * called.
+ */
+class ModuleGlobals {
+ public:
+ void Add(const Module& m) {
+ std::scoped_lock<std::mutex> lock(mutex_);
+ modules_.Set(m, 1);
+ }
+
+ void Remove(const Module& m) {
+ std::scoped_lock<std::mutex> lock(mutex_);
+ modules_.erase(m);
+ }
+
+ static ModuleGlobals* Get() {
+ static ModuleGlobals instance;
+ return &instance;
+ }
+
+ private:
+ Map<Module, int> modules_;
+ std::mutex mutex_;
+};
+
Optional<Function> ModuleObj::GetFunction(const String& name, bool
query_imports) {
if (auto opt_func = this->GetFunction(name)) {
return opt_func;
@@ -161,7 +189,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_method("ffi.ModuleGetWriteFormats", &ModuleObj::GetWriteFormats)
.def_method("ffi.ModuleWriteToFile", &ModuleObj::WriteToFile)
.def_method("ffi.ModuleImportModule", &ModuleObj::ImportModule)
- .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports);
+ .def_method("ffi.ModuleClearImports", &ModuleObj::ClearImports)
+ .def_method("ffi.ModuleGlobalsAdd", [](const Module& mod) {
ModuleGlobals::Get()->Add(mod); })
+ .def_method("ffi.ModuleGlobalsRemove",
+ [](const Module& mod) { ModuleGlobals::Get()->Remove(mod);
});
}
} // namespace ffi
} // namespace tvm
diff --git a/tests/python/test_build.py b/tests/python/test_build.py
index 7f22669..8844616 100644
--- a/tests/python/test_build.py
+++ b/tests/python/test_build.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import gc
import pathlib
import numpy
@@ -196,10 +195,6 @@ def test_build_inline_with_metadata() -> None: # noqa:
PLR0915
schema = TypeSchema.from_json_str(metadata["type_schema"])
assert str(schema) == "Callable[[Tensor, Tensor], None]"
- # Explicitly cleanup all objects before module unload to avoid
use-after-free
- del metadata, schema, result, x, y, mod
- gc.collect()
-
def test_build_inline_with_docstrings() -> None:
"""Test building functions with documentation using the functions dict."""
@@ -279,10 +274,6 @@ def test_build_inline_with_docstrings() -> None:
assert doc is not None, "divide should have documentation"
assert doc == divide_docstring
- # Explicitly cleanup all objects before module unload to avoid
use-after-free
- del metadata, schema, doc, result, mod
- gc.collect()
-
def test_build_without_metadata() -> None:
"""Test building without metadata export."""
diff --git a/tests/python/test_stl.py b/tests/python/test_stl.py
index 43037a6..3ba70b8 100644
--- a/tests/python/test_stl.py
+++ b/tests/python/test_stl.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import gc
import pathlib
import pytest
@@ -31,33 +30,21 @@ def test_stl() -> None:
mod: Module = tvm_ffi.load_module(output_lib_path)
- def run_check(mod: Module) -> None:
- # This sub function is needed to make sure all temp variables
deallocated
- # before module unload since some of these objects contains deleters
in the library
- # code. If the module is unloaded before the object is deleted, the
deleter
- # may call an invalid address.
- assert list(mod.test_tuple([1, 2.5])) == [2.5, 1]
- assert mod.test_vector(None) is None
- assert list(mod.test_vector([[1, 2], [3, 4]])) == [3, 7]
- assert mod.test_variant(1) == "int"
- assert mod.test_variant(1.0) == "float"
- assert list(mod.test_variant([1, 1.0])) == ["int", "float"]
- assert dict(mod.test_map({"a": 1, "b": 2})) == {1: "a", 2: "b"}
- assert dict(mod.test_map_2({"a": 1, "b": 2})) == {1: "a", 2: "b"}
- assert mod.test_function(lambda: 0)() == 1
- assert mod.test_function(lambda: 10)() == 11
-
- with pytest.raises(TypeError):
- mod.test_tuple([1.5, 2.5])
- with pytest.raises(TypeError):
- mod.test_function(lambda: 0)(100)
-
- run_check(mod)
- # Force garbage collection to ensure that all objects created from the
- # module are destroyed before the module itself is unloaded. This
- # prevents segfaults from calling destructors in an unloaded library.
- gc.collect()
- del mod
+ assert list(mod.test_tuple([1, 2.5])) == [2.5, 1]
+ assert mod.test_vector(None) is None
+ assert list(mod.test_vector([[1, 2], [3, 4]])) == [3, 7]
+ assert mod.test_variant(1) == "int"
+ assert mod.test_variant(1.0) == "float"
+ assert list(mod.test_variant([1, 1.0])) == ["int", "float"]
+ assert dict(mod.test_map({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+ assert dict(mod.test_map_2({"a": 1, "b": 2})) == {1: "a", 2: "b"}
+ assert mod.test_function(lambda: 0)() == 1
+ assert mod.test_function(lambda: 10)() == 11
+
+ with pytest.raises(TypeError):
+ mod.test_tuple([1.5, 2.5])
+ with pytest.raises(TypeError):
+ mod.test_function(lambda: 0)(100)
if __name__ == "__main__":