This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 683dfb0c04 [RPC] Report RPC Session Timeout to Client Instead of 
"kShutdown" (#15187)
683dfb0c04 is described below

commit 683dfb0c04d9f2296940e89c60c2277aca095ccd
Author: Qiang Zhang <[email protected]>
AuthorDate: Sun Jul 2 12:06:38 2023 +0800

    [RPC] Report RPC Session Timeout to Client Instead of "kShutdown" (#15187)
    
    When the RPC server runs on an NPU board, a compiled model will sometimes hang 
the NPU because of buggy operator libraries in the NPU toolchain, so we must use 
`session_timeout` to ensure that board resources held by hung jobs can be 
released.
    
    Currently the RPC server handles a session timeout poorly: it just kills the 
server-loop subprocess, and then the destructor of class `RPCEndpoint` sends the 
code `kShutdown` to the RPC client. The RPC client, however, expects to receive 
`kReturn` or `kException`, so users see an error message like the one reported in 
https://github.com/apache/tvm/issues/15151. This error report is very confusing 
and does not tell users what actually happened.
    
    When tuning to search for a good schedule for operators, we want to ignore 
only the RPC session timeout error, which indicates that the generated schedule is 
an illegal one. Other errors reported by the RPC server may help us find potential 
bugs in our toolchain built on top of TVM. Therefore the RPC session timeout error 
should be split into a standalone TVM error class.
    
    This PR implements these requirements by sending the RPC session timeout 
error message to the RPC client as an RPC server exception before killing the 
server-loop subprocess.
---
 python/tvm/error.py                       |  5 ++
 python/tvm/rpc/server.py                  | 88 ++++++++++++++++---------------
 src/runtime/rpc/rpc_endpoint.cc           |  8 ++-
 src/runtime/rpc/rpc_socket_impl.cc        | 34 ++++++++++++
 tests/python/unittest/test_runtime_rpc.py | 31 +++++++++++
 5 files changed, 121 insertions(+), 45 deletions(-)

diff --git a/python/tvm/error.py b/python/tvm/error.py
index afd079ca6b..cc0180593d 100644
--- a/python/tvm/error.py
+++ b/python/tvm/error.py
@@ -61,6 +61,11 @@ class RPCError(TVMError):
     """Error thrown by the remote server handling the RPC call."""
 
 
+@register_error
+class RPCSessionTimeoutError(RPCError, TimeoutError):
+    """Error thrown by the remote server when the RPC session has expired."""
+
+
 @register_error
 class OpError(TVMError):
     """Base class of all operator errors in frontends."""
diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index 5a2ad522eb..6ee683c73b 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -25,6 +25,7 @@ Server is TCP based with the following protocol:
    - {server|client}:device-type[:random-key] [-timeout=timeout]
 """
 # pylint: disable=invalid-name
+import os
 import ctypes
 import socket
 import select
@@ -118,16 +119,6 @@ def _server_env(load_library, work_path=None):
     return temp
 
 
-def _serve_loop(sock, addr, load_library, work_path=None):
-    """Server loop"""
-    sockfd = sock.fileno()
-    temp = _server_env(load_library, work_path)
-    _ffi_api.ServerLoop(sockfd)
-    if not work_path:
-        temp.remove()
-    logger.info("Finish serving %s", addr)
-
-
 def _parse_server_opt(opts):
     # parse client options
     ret = {}
@@ -137,6 +128,47 @@ def _parse_server_opt(opts):
     return ret
 
 
+def _serving(sock, addr, opts, load_library):
+    logger.info(f"connected from {addr}")
+    work_path = utils.tempdir()
+    old_cwd = os.getcwd()
+    os.chdir(work_path.path)  # Avoiding file name conflict between sessions.
+    logger.info(f"start serving at {work_path.path}")
+
+    def _serve_loop():
+        _server_env(load_library, work_path)
+        _ffi_api.ServerLoop(sock.fileno())
+
+    server_proc = multiprocessing.Process(target=_serve_loop)
+    server_proc.start()
+    server_proc.join(opts.get("timeout", None))  # Wait until finish or 
timeout.
+
+    if server_proc.is_alive():
+        logger.info("timeout in RPC session, kill..")
+        _ffi_api.ReturnException(
+            sock.fileno(),
+            f'RPCSessionTimeoutError: Your {opts["timeout"]}s session has 
expired, '
+            f'try to increase the "session_timeout" value.',
+        )
+
+        try:
+            import psutil  # pylint: disable=import-outside-toplevel
+
+            # Terminate worker children firstly.
+            for child in 
psutil.Process(server_proc.pid).children(recursive=True):
+                child.terminate()
+        except ImportError:
+            # Don't dependent `psutil` hardly, because it isn't a pure Python
+            # package and maybe hard to be installed on some platforms.
+            pass
+        server_proc.terminate()
+
+    logger.info(f"finish serving {addr}")
+    os.chdir(old_cwd)
+    work_path.remove()
+    sock.close()
+
+
 def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
     """Listening loop of the server."""
 
@@ -237,30 +269,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, 
load_library, custom_addr):
             raise exc
 
         # step 3: serving
-        work_path = utils.tempdir()
-        logger.info("connection from %s", addr)
-        server_proc = multiprocessing.Process(
-            target=_serve_loop, args=(conn, addr, load_library, work_path)
-        )
-
-        server_proc.start()
-        # close from our side.
-        conn.close()
-        # wait until server process finish or timeout
-        server_proc.join(opts.get("timeout", None))
-
-        if server_proc.is_alive():
-            logger.info("Timeout in RPC session, kill..")
-            # pylint: disable=import-outside-toplevel
-            import psutil
-
-            parent = psutil.Process(server_proc.pid)
-            # terminate worker children
-            for child in parent.children(recursive=True):
-                child.terminate()
-            # terminate the worker
-            server_proc.terminate()
-        work_path.remove()
+        _serving(conn, addr, opts, load_library)
 
 
 def _connect_proxy_loop(addr, key, load_library):
@@ -285,15 +294,8 @@ def _connect_proxy_loop(addr, key, load_library):
                 raise RuntimeError(f"{str(addr)} is not RPC Proxy")
             keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
             remote_key = py_str(base.recvall(sock, keylen))
-            opts = _parse_server_opt(remote_key.split()[1:])
-            logger.info("connected to %s", str(addr))
-            process = multiprocessing.Process(target=_serve_loop, args=(sock, 
addr, load_library))
-            process.start()
-            sock.close()
-            process.join(opts.get("timeout", None))
-            if process.is_alive():
-                logger.info("Timeout in RPC session, kill..")
-                process.terminate()
+
+            _serving(sock, addr, _parse_server_opt(remote_key.split()[1:]), 
load_library)
             retry_count = 0
         except (socket.error, IOError) as err:
             retry_count += 1
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index 46710587ab..30606adf1b 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -40,6 +40,7 @@
 
 #include "../../support/arena.h"
 #include "../../support/ring_buffer.h"
+#include "../../support/utils.h"
 #include "../object_internal.h"
 #include "rpc_local_session.h"
 
@@ -372,8 +373,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     if (code == RPCCode::kException) {
       // switch to the state before sending exception.
       this->SwitchToState(kRecvPacketNumBytes);
-      std::string msg = args[0];
-      LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg;
+      String msg = args[0];
+      if (!support::StartsWith(msg, "RPCSessionTimeoutError: ")) {
+        msg = "RPCError: Error caught from RPC call:\n" + msg;
+      }
+      LOG(FATAL) << msg;
     }
 
     ICHECK(setreturn != nullptr) << "fsetreturn not available";
diff --git a/src/runtime/rpc/rpc_socket_impl.cc 
b/src/runtime/rpc/rpc_socket_impl.cc
index 3cc8cdc51f..1d0b5d5470 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -142,5 +142,39 @@ TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs 
args, TVMRetValue* rv)
   }
 });
 
+class SimpleSockHandler : public dmlc::Stream {
+  // Things that will interface with user directly.
+ public:
+  explicit SimpleSockHandler(int sockfd)
+      : sock_(static_cast<support::TCPSocket::SockType>(sockfd)) {}
+  using dmlc::Stream::Read;
+  using dmlc::Stream::ReadArray;
+  using dmlc::Stream::Write;
+  using dmlc::Stream::WriteArray;
+
+  // Unused here, implemented for microTVM framing layer.
+  void MessageStart(uint64_t packet_nbytes) {}
+  void MessageDone() {}
+
+  // Internal supporting.
+  // Override methods that inherited from dmlc::Stream.
+ private:
+  size_t Read(void* data, size_t size) final {
+    ICHECK_EQ(sock_.RecvAll(data, size), size);
+    return size;
+  }
+  void Write(const void* data, size_t size) final { 
ICHECK_EQ(sock_.SendAll(data, size), size); }
+
+  // Things of current class.
+ private:
+  support::TCPSocket sock_;
+};
+
+TVM_REGISTER_GLOBAL("rpc.ReturnException").set_body_typed([](int sockfd, 
String msg) {
+  auto handler = SimpleSockHandler(sockfd);
+  RPCReference::ReturnException(msg.c_str(), &handler);
+  return;
+});
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/unittest/test_runtime_rpc.py 
b/tests/python/unittest/test_runtime_rpc.py
index 97016684a6..de441948b1 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -606,3 +606,34 @@ def test_rpc_tracker_via_proxy(device_key):
     server1.terminate()
     proxy_server.terminate()
     tracker_server.terminate()
+
+
[email protected]_rpc
[email protected]("with_proxy", (True, False))
+def test_rpc_session_timeout_error(with_proxy):
+    port = 9000
+    port_end = 10000
+
+    tracker = Tracker(port=port, port_end=port_end)
+    time.sleep(0.5)
+    tracker_addr = (tracker.host, tracker.port)
+
+    if with_proxy:
+        proxy = Proxy(host="0.0.0.0", port=port, port_end=port_end, 
tracker_addr=tracker_addr)
+        time.sleep(0.5)
+        server = rpc.Server(host=proxy.host, port=proxy.port, is_proxy=True, 
key="x1")
+    else:
+        server = rpc.Server(port=port, port_end=port_end, 
tracker_addr=tracker_addr, key="x1")
+    time.sleep(0.5)
+
+    rpc_sess = rpc.connect_tracker(*tracker_addr).request(key="x1", 
session_timeout=1)
+
+    with pytest.raises(tvm.error.RPCSessionTimeoutError):
+        f1 = rpc_sess.get_function("rpc.test.addone")
+        time.sleep(2)
+        f1(10)
+
+    server.terminate()
+    if with_proxy:
+        proxy.terminate()
+    tracker.terminate()

Reply via email to