This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0bbaf0e  [Bug Fixed] Make query_rpc_tracker show the correct device server port and customized address (#8203)
0bbaf0e is described below

commit 0bbaf0eb985b0075fdcec8f27ec1b795067d0ea7
Author: Muyang Li <[email protected]>
AuthorDate: Sun Jun 20 22:59:48 2021 +0800

    [Bug Fixed] Make query_rpc_tracker show the correct device server port and customized address (#8203)
---
 python/tvm/rpc/server.py                  |  4 ++--
 python/tvm/rpc/tracker.py                 | 10 ++++++---
 tests/python/unittest/test_runtime_rpc.py | 37 +++++++++++++++++++++++++++----
 3 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index c07e88b..0b49b67 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -143,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, 
load_library, custom_addr):
         listen_sock: Socket
             The socket used by listening process.
 
-        tracker_conn : connnection to tracker
+        tracker_conn : connection to tracker
             Tracker connection
 
         ping_period : float, optional
@@ -216,7 +216,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, 
load_library, custom_addr):
                 if magic != base.RPC_TRACKER_MAGIC:
                     raise RuntimeError("%s is not RPC Tracker" % 
str(tracker_addr))
                 # report status of current queue
-                cinfo = {"key": "server:" + rpc_key}
+                cinfo = {"key": "server:" + rpc_key, "addr": (custom_addr, 
port)}
                 base.sendjson(tracker_conn, [TrackerCode.UPDATE_INFO, cinfo])
                 assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
 
diff --git a/python/tvm/rpc/tracker.py b/python/tvm/rpc/tracker.py
index 9506a52..c8ab15b 100644
--- a/python/tvm/rpc/tracker.py
+++ b/python/tvm/rpc/tracker.py
@@ -67,7 +67,7 @@ logger = logging.getLogger("RPCTracker")
 
 
 class Scheduler(object):
-    """Abstratc interface of scheduler."""
+    """Abstract interface of scheduler."""
 
     def put(self, value):
         """Push a resource into the scheduler.
@@ -167,7 +167,7 @@ class TCPEventHandler(tornado_util.TCPHandler):
         self._msg_size = 0
         self._addr = addr
         self._init_req_nbytes = 4
-        self._info = {"addr": addr}
+        self._info = {}
         # list of pending match keys that has not been used.
         self.pending_matchkeys = set()
         self._tracker._connections.add(self)
@@ -272,7 +272,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
             else:
                 self.ret_value(TrackerCode.FAIL)
         elif code == TrackerCode.UPDATE_INFO:
-            self._info.update(args[1])
+            info = args[1]
+            assert isinstance(info, dict)
+            if info["addr"][0] is None:
+                info["addr"][0] = self._addr[0]
+            self._info.update(info)
             self.ret_value(TrackerCode.SUCCESS)
         elif code == TrackerCode.SUMMARY:
             status = self._tracker.summary()
diff --git a/tests/python/unittest/test_runtime_rpc.py 
b/tests/python/unittest/test_runtime_rpc.py
index 1826451..f90c954 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -401,33 +401,62 @@ def test_rpc_tracker_register():
     # test registration
     tracker = Tracker(port=9000, port_end=10000)
     device_key = "test_device"
-    server = rpc.Server(
+    server1 = rpc.Server(
+        host="127.0.0.1",
+        port=9000,
+        port_end=10000,
+        key=device_key,
+        tracker_addr=("127.0.0.1", tracker.port),
+    )
+    server2 = rpc.Server(
+        host="127.0.0.1",
         port=9000,
         port_end=10000,
         key=device_key,
         tracker_addr=("127.0.0.1", tracker.port),
+        custom_addr="test_addr",  # this is a test address, which is unable to 
connect
     )
     time.sleep(1)
     client = rpc.connect_tracker("127.0.0.1", tracker.port)
 
+    def exist_address(summary, key, host, port):
+        server_info = summary["server_info"]
+        for device in server_info:
+            if device["key"] == "server:%s" % key:
+                addr = device["addr"]
+                if (host is None or host == addr[0]) and port == addr[1]:
+                    return True
+        return False
+
     summary = client.summary()
-    assert summary["queue_info"][device_key]["free"] == 1
+    assert summary["queue_info"][device_key]["free"] == 2
+    assert exist_address(summary, device_key, "127.0.0.1", server1.port)
+    assert exist_address(summary, device_key, "test_addr", server2.port)
 
     remote = client.request(device_key)
     summary = client.summary()
-    assert summary["queue_info"][device_key]["free"] == 0
+    assert summary["queue_info"][device_key]["free"] == 1
 
     del remote
     time.sleep(1)
 
     summary = client.summary()
+    assert summary["queue_info"][device_key]["free"] == 2
+
+    server1.terminate()
+    time.sleep(1)
+
+    summary = client.summary()
     assert summary["queue_info"][device_key]["free"] == 1
+    assert not exist_address(summary, device_key, "127.0.0.1", server1.port)
+    assert exist_address(summary, device_key, "test_addr", server2.port)
 
-    server.terminate()
+    server2.terminate()
     time.sleep(1)
 
     summary = client.summary()
     assert summary["queue_info"][device_key]["free"] == 0
+    assert not exist_address(summary, device_key, "test_addr", server2.port)
 
     tracker.terminate()
 

Reply via email to