This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0bbaf0e [Bug Fixed] Make query_rpc_tracker show the correct device
server port and customized address (#8203)
0bbaf0e is described below
commit 0bbaf0eb985b0075fdcec8f27ec1b795067d0ea7
Author: Muyang Li <[email protected]>
AuthorDate: Sun Jun 20 22:59:48 2021 +0800
[Bug Fixed] Make query_rpc_tracker show the correct device server port and
customized address (#8203)
---
python/tvm/rpc/server.py | 4 ++--
python/tvm/rpc/tracker.py | 10 ++++++---
tests/python/unittest/test_runtime_rpc.py | 37 +++++++++++++++++++++++++++----
3 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index c07e88b..0b49b67 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -143,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
listen_sock: Socket
The socket used by listening process.
- tracker_conn : connnection to tracker
+ tracker_conn : connection to tracker
Tracker connection
ping_period : float, optional
@@ -216,7 +216,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if magic != base.RPC_TRACKER_MAGIC:
raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
# report status of current queue
- cinfo = {"key": "server:" + rpc_key}
+ cinfo = {"key": "server:" + rpc_key, "addr": (custom_addr, port)}
base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 9506a52..c8ab15b 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -67,7 +67,7 @@ logger = logging.getLogger("RPCTracker")
class Scheduler(object):
- """Abstratc interface of scheduler."""
+ """Abstract interface of scheduler."""
def put(self, value):
"""Push a resource into the scheduler.
@@ -167,7 +167,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._msg_size = 0
self._addr = addr
self._init_req_nbytes = 4
- self._info = {"addr": addr}
+ self._info = {}
# list of pending match keys that has not been used.
self.pending_matchkeys = set()
self._tracker._connections.add(self)
@@ -272,7 +272,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
else:
self.ret_value(TrackerCode.FAIL)
elif code == TrackerCode.UPDATE_INFO:
- self._info.update(args[1])
+ info = args[1]
+ assert isinstance(info, dict)
+ if info["addr"][0] is None:
+ info["addr"][0] = self._addr[0]
+ self._info.update(info)
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.SUMMARY:
status = self._tracker.summary()
diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py
index 1826451..f90c954 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -401,33 +401,62 @@ def test_rpc_tracker_register():
# test registration
tracker = Tracker(port=9000, port_end=10000)
device_key = "test_device"
- server = rpc.Server(
+ server1 = rpc.Server(
+ host="127.0.0.1",
+ port=9000,
+ port_end=10000,
+ key=device_key,
+ tracker_addr=("127.0.0.1", tracker.port),
+ )
+ server2 = rpc.Server(
+ host="127.0.0.1",
port=9000,
port_end=10000,
key=device_key,
tracker_addr=("127.0.0.1", tracker.port),
+ custom_addr="test_addr", # this is a test address, which is unable to connect
)
time.sleep(1)
client = rpc.connect_tracker("127.0.0.1", tracker.port)
+ def exist_address(summary, key, host, port):
+ server_info = summary["server_info"]
+ for device in server_info:
+ if device["key"] == "server:%s" % key:
+ addr = device["addr"]
+ if (host is None or host == addr[0]) and port == addr[1]:
+ return True
+ return False
+
summary = client.summary()
- assert summary["queue_info"][device_key]["free"] == 1
+ assert summary["queue_info"][device_key]["free"] == 2
+ assert exist_address(summary, device_key, "127.0.0.1", server1.port)
+ assert exist_address(summary, device_key, "test_addr", server2.port)
remote = client.request(device_key)
summary = client.summary()
- assert summary["queue_info"][device_key]["free"] == 0
+ assert summary["queue_info"][device_key]["free"] == 1
del remote
time.sleep(1)
summary = client.summary()
+ assert summary["queue_info"][device_key]["free"] == 2
+
+ server1.terminate()
+ time.sleep(1)
+
+ summary = client.summary()
assert summary["queue_info"][device_key]["free"] == 1
+ assert not exist_address(summary, device_key, "127.0.0.1", server1.port)
+ assert exist_address(summary, device_key, "test_addr", server2.port)
- server.terminate()
+ server2.terminate()
time.sleep(1)
summary = client.summary()
assert summary["queue_info"][device_key]["free"] == 0
+ assert not exist_address(summary, device_key, "test_addr", server2.port)
tracker.terminate()