This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 284faf2  [RPC] Make tracker jupyter friendly (#7961)
284faf2 is described below

commit 284faf241f173270064124d61a620503556860e7
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon May 3 16:41:02 2021 -0400

    [RPC] Make tracker jupyter friendly (#7961)
    
    This PR uses PopenWorker to handle tracker start-up
    and makes the tracker Jupyter-friendly.
---
 python/tvm/contrib/popen_pool.py | 22 ++++++++--
 python/tvm/exec/rpc_tracker.py   | 24 -----------
 python/tvm/rpc/tracker.py        | 89 ++++++++++++++++++++++++++++------------
 3 files changed, 82 insertions(+), 53 deletions(-)

diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
index ecda995..2f55203 100644
--- a/python/tvm/contrib/popen_pool.py
+++ b/python/tvm/contrib/popen_pool.py
@@ -153,10 +153,26 @@ class PopenWorker:
         self._reader = os.fdopen(main_read, "rb")
         self._writer = os.fdopen(main_write, "wb")
 
-    def join(self):
-        """Join the current process worker before it terminates"""
+    def join(self, timeout=None):
+        """Join the current process worker before it terminates.
+
+        Parameters
+        ----------
+        timeout: Optional[number]
+            Timeout value, block at most timeout seconds if it
+            is a positive number.
+        """
+        if self._proc:
+            try:
+                self._proc.wait(timeout)
+            except subprocess.TimeoutExpired:
+                pass
+
+    def is_alive(self):
+        """Check if the process is alive"""
         if self._proc:
-            self._proc.wait()
+            return self._proc.poll() is None
+        return False
 
     def send(self, fn, args=(), kwargs=None, timeout=None):
         """Send a new function task fn(*args, **kwargs) to the subprocess.
diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py
index 05809e0..091e95a 100644
--- a/python/tvm/exec/rpc_tracker.py
+++ b/python/tvm/exec/rpc_tracker.py
@@ -16,12 +16,8 @@
 # under the License.
 # pylint: disable=redefined-outer-name, invalid-name
 """Tool to start RPC tracker"""
-from __future__ import absolute_import
-
 import logging
 import argparse
-import multiprocessing
-import sys
 from ..rpc.tracker import Tracker
 
 
@@ -38,27 +34,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--port", type=int, default=9190, help="The port of the RPC")
     parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
-    parser.add_argument(
-        "--no-fork",
-        dest="fork",
-        action="store_false",
-        help="Use spawn mode to avoid fork. This option \
-                         is able to avoid potential fork problems with Metal, OpenCL \
-                         and ROCM compilers.",
-    )
     parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")
-
-    parser.set_defaults(fork=True)
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
-    if args.fork is False:
-        if sys.version_info[0] < 3:
-            raise RuntimeError("Python3 is required for spawn mode.")
-        multiprocessing.set_start_method("spawn")
-    else:
-        if not args.silent:
-            logging.info(
-                "If you are running ROCM/Metal, fork will cause "
-                "compiler internal error. Try to launch with arg ```--no-fork```"
-            )
     main(args)
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 9dc4139..25ff15c 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -41,14 +41,15 @@ List of available APIs:
 """
 # pylint: disable=invalid-name
 
+import asyncio
 import heapq
 import logging
 import socket
 import threading
-import multiprocessing
 import errno
 import struct
 import json
+from tvm.contrib.popen_pool import PopenWorker
 
 try:
     from tornado import ioloop
@@ -362,14 +363,55 @@ class TrackerServerHandler(object):
 
 
 def _tracker_server(listen_sock, stop_key):
+    asyncio.set_event_loop(asyncio.new_event_loop())
     handler = TrackerServerHandler(listen_sock, stop_key)
     handler.run()
 
 
+class PopenTrackerServerState(object):
+    """Internal PopenTrackerServer State"""
+
+    current = None
+
+    def __init__(self, host, port=9190, port_end=9199, silent=False):
+        if silent:
+            logger.setLevel(logging.WARN)
+
+        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
+        self.port = None
+        self.stop_key = base.random_key("tracker")
+        for my_port in range(port, port_end):
+            try:
+                sock.bind((host, my_port))
+                self.port = my_port
+                break
+            except socket.error as sock_err:
+                if sock_err.errno in [errno.EADDRINUSE]:
+                    continue
+                raise sock_err
+        if not self.port:
+            raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
+        logger.info("bind to %s:%d", host, self.port)
+        sock.listen(1)
+        self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key))
+        self.thread.start()
+        self.host = host
+
+
+def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False):
+    # This is a function that will be sent to the
+    # Popen worker to run on a separate process.
+    # Create and start the server in a different thread
+    state = PopenTrackerServerState(host, port, port_end, silent)
+    PopenTrackerServerState.current = state
+    # returns the port so that the main can get the port number.
+    return (state.port, state.stop_key)
+
+
 class Tracker(object):
     """Start RPC tracker on a separate process.
 
-    Python implementation based on multi-processing.
+    Python implementation based on PopenWorker.
 
     Parameters
     ----------
@@ -389,28 +431,20 @@ class Tracker(object):
     def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False):
         if silent:
             logger.setLevel(logging.WARN)
-
-        sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
-        self.port = None
-        self.stop_key = base.random_key("tracker")
-        for my_port in range(port, port_end):
-            try:
-                sock.bind((host, my_port))
-                self.port = my_port
-                break
-            except socket.error as sock_err:
-                if sock_err.errno in [errno.EADDRINUSE]:
-                    continue
-                raise sock_err
-        if not self.port:
-            raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
-        logger.info("bind to %s:%d", host, self.port)
-        sock.listen(1)
-        self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key))
-        self.proc.start()
+        self.proc = PopenWorker()
+        # send the function
+        self.proc.send(
+            _popen_start_tracker_server,
+            [
+                host,
+                port,
+                port_end,
+                silent,
+            ],
+        )
+        # receive the port
+        self.port, self.stop_key = self.proc.recv()
         self.host = host
-        # close the socket on this process
-        sock.close()
 
     def _stop_tracker(self):
         sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
@@ -427,11 +461,14 @@ class Tracker(object):
         if self.proc:
             if self.proc.is_alive():
                 self._stop_tracker()
-                self.proc.join(1)
+            self.proc.join(0.1)
             if self.proc.is_alive():
                 logger.info("Terminating Tracker Server...")
-                self.proc.terminate()
+                self.proc.kill()
             self.proc = None
 
     def __del__(self):
-        self.terminate()
+        try:
+            self.terminate()
+        except TypeError:
+            pass

Reply via email to