This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 284faf2 [RPC] Make tracker jupyter friendly (#7961)
284faf2 is described below
commit 284faf241f173270064124d61a620503556860e7
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon May 3 16:41:02 2021 -0400
[RPC] Make tracker jupyter friendly (#7961)
This PR uses PopenWorker to handle tracker startup
and makes the tracker Jupyter-friendly.
---
python/tvm/contrib/popen_pool.py | 22 ++++++++--
python/tvm/exec/rpc_tracker.py | 24 -----------
python/tvm/rpc/tracker.py | 89 ++++++++++++++++++++++++++++------------
3 files changed, 82 insertions(+), 53 deletions(-)
diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py
index ecda995..2f55203 100644
--- a/python/tvm/contrib/popen_pool.py
+++ b/python/tvm/contrib/popen_pool.py
@@ -153,10 +153,26 @@ class PopenWorker:
self._reader = os.fdopen(main_read, "rb")
self._writer = os.fdopen(main_write, "wb")
- def join(self):
- """Join the current process worker before it terminates"""
+ def join(self, timeout=None):
+ """Join the current process worker before it terminates.
+
+ Parameters
+ ----------
+ timeout: Optional[number]
+ Timeout value, block at most timeout seconds if it
+ is a positive number.
+ """
+ if self._proc:
+ try:
+ self._proc.wait(timeout)
+ except subprocess.TimeoutExpired:
+ pass
+
+ def is_alive(self):
+ """Check if the process is alive"""
if self._proc:
- self._proc.wait()
+ return self._proc.poll() is None
+ return False
def send(self, fn, args=(), kwargs=None, timeout=None):
"""Send a new function task fn(*args, **kwargs) to the subprocess.
diff --git a/python/tvm/exec/rpc_tracker.py b/python/tvm/exec/rpc_tracker.py
index 05809e0..091e95a 100644
--- a/python/tvm/exec/rpc_tracker.py
+++ b/python/tvm/exec/rpc_tracker.py
@@ -16,12 +16,8 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Tool to start RPC tracker"""
-from __future__ import absolute_import
-
import logging
import argparse
-import multiprocessing
-import sys
from ..rpc.tracker import Tracker
@@ -38,27 +34,7 @@ if __name__ == "__main__":
)
parser.add_argument("--port", type=int, default=9190, help="The port of the RPC")
parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
- parser.add_argument(
- "--no-fork",
- dest="fork",
- action="store_false",
- help="Use spawn mode to avoid fork. This option \
- is able to avoid potential fork problems with Metal, OpenCL \
- and ROCM compilers.",
- )
parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")
-
- parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
- if args.fork is False:
- if sys.version_info[0] < 3:
- raise RuntimeError("Python3 is required for spawn mode.")
- multiprocessing.set_start_method("spawn")
- else:
- if not args.silent:
- logging.info(
- "If you are running ROCM/Metal, fork will cause "
- "compiler internal error. Try to launch with arg ```--no-fork```"
- )
main(args)
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 9dc4139..25ff15c 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -41,14 +41,15 @@ List of available APIs:
"""
# pylint: disable=invalid-name
+import asyncio
import heapq
import logging
import socket
import threading
-import multiprocessing
import errno
import struct
import json
+from tvm.contrib.popen_pool import PopenWorker
try:
from tornado import ioloop
@@ -362,14 +363,55 @@ class TrackerServerHandler(object):
def _tracker_server(listen_sock, stop_key):
+ asyncio.set_event_loop(asyncio.new_event_loop())
handler = TrackerServerHandler(listen_sock, stop_key)
handler.run()
+class PopenTrackerServerState(object):
+ """Internal PopenTrackerServer State"""
+
+ current = None
+
+ def __init__(self, host, port=9190, port_end=9199, silent=False):
+ if silent:
+ logger.setLevel(logging.WARN)
+
+ sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
+ self.port = None
+ self.stop_key = base.random_key("tracker")
+ for my_port in range(port, port_end):
+ try:
+ sock.bind((host, my_port))
+ self.port = my_port
+ break
+ except socket.error as sock_err:
+ if sock_err.errno in [errno.EADDRINUSE]:
+ continue
+ raise sock_err
+ if not self.port:
+ raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
+ logger.info("bind to %s:%d", host, self.port)
+ sock.listen(1)
+ self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key))
+ self.thread.start()
+ self.host = host
+
+
+def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False):
+ # This is a function that will be sent to the
+ # Popen worker to run on a separate process.
+ # Create and start the server in a different thread
+ state = PopenTrackerServerState(host, port, port_end, silent)
+ PopenTrackerServerState.current = state
+ # returns the port so that the main can get the port number.
+ return (state.port, state.stop_key)
+
+
class Tracker(object):
"""Start RPC tracker on a separate process.
- Python implementation based on multi-processing.
+ Python implementation based on PopenWorker.
Parameters
----------
@@ -389,28 +431,20 @@ class Tracker(object):
def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False):
if silent:
logger.setLevel(logging.WARN)
-
- sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
- self.port = None
- self.stop_key = base.random_key("tracker")
- for my_port in range(port, port_end):
- try:
- sock.bind((host, my_port))
- self.port = my_port
- break
- except socket.error as sock_err:
- if sock_err.errno in [errno.EADDRINUSE]:
- continue
- raise sock_err
- if not self.port:
- raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
- logger.info("bind to %s:%d", host, self.port)
- sock.listen(1)
- self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key))
- self.proc.start()
+ self.proc = PopenWorker()
+ # send the function
+ self.proc.send(
+ _popen_start_tracker_server,
+ [
+ host,
+ port,
+ port_end,
+ silent,
+ ],
+ )
+ # receive the port
+ self.port, self.stop_key = self.proc.recv()
self.host = host
- # close the socket on this process
- sock.close()
def _stop_tracker(self):
sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
@@ -427,11 +461,14 @@ class Tracker(object):
if self.proc:
if self.proc.is_alive():
self._stop_tracker()
- self.proc.join(1)
+ self.proc.join(0.1)
if self.proc.is_alive():
logger.info("Terminating Tracker Server...")
- self.proc.terminate()
+ self.proc.kill()
self.proc = None
def __del__(self):
- self.terminate()
+ try:
+ self.terminate()
+ except TypeError:
+ pass