This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 99e22328bf [Disco] Implement `Session.import_python_module` method (#16617)
99e22328bf is described below

commit 99e22328bf5c33d3c7f350ec41cb5aac9cfc69c4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Feb 26 04:05:33 2024 -0600

    [Disco] Implement `Session.import_python_module` method (#16617)
    
    Import a module into the workers.  If a python module has not yet been
    loaded, `Session.get_global_func` cannot load a packed func from it.
---
 python/tvm/runtime/__init__.py      |  1 +
 python/tvm/runtime/disco/session.py | 24 +++++++++++++++++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index eccdcbad95..3a68c567ee 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -40,3 +40,4 @@ from .params import (
 )
 
 from . import executor
+from . import disco
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py
index b166bd82e9..c54f646e17 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -21,7 +21,7 @@ from typing import Any, Callable, Optional, Sequence, Union
 
 import numpy as np
 
-from ..._ffi import register_object
+from ..._ffi import register_object, register_func
 from ..._ffi.runtime_ctypes import Device
 from ..container import ShapeTuple
 from ..ndarray import NDArray
@@ -153,6 +153,23 @@ class Session(Object):
         """
         return DPackedFunc(_ffi_api.SessionGetGlobalFunc(self, name), self)  # type: ignore # pylint: disable=no-member
 
+    def import_python_module(self, module_name: str) -> None:
+        """Import a python module in each worker
+
+        This may be required before `get_global_func` can load a function defined by that module.
+
+        Parameters
+        ----------
+        module_name: str
+
+            The python module name, as it would be used in a python
+            `import` statement.
+        """
+        if not hasattr(self, "_import_python_module"):
+            self._import_python_module = self.get_global_func("runtime.disco._import_python_module")
+
+        self._import_python_module(module_name)
+
     def call_packed(self, func: DRef, *args) -> DRef:
         """Call a PackedFunc on workers providing variadic arguments.
 
@@ -369,6 +386,11 @@ class ProcessSession(Session):
         )
 
 
+@register_func("runtime.disco._import_python_module")
+def _import_python_module(module_name: str) -> None:
+    __import__(module_name)
+
+
 REDUCE_OPS = {
     "sum": 0,
     "prod": 1,

Reply via email to