This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 759ee1236a [Support] Add Interrupt Handling in Pipe (#16255)
759ee1236a is described below

commit 759ee1236ad1b95fc6b9c356693a8f1062ceee13
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Dec 20 00:27:04 2023 +0800

    [Support] Add Interrupt Handling in Pipe (#16255)
    
    
    
    Co-authored-by: Sunghyun Park <[email protected]>
---
 src/support/errno_handling.h | 69 ++++++++++++++++++++++++++++++++++++++++++++
 src/support/pipe.h           | 42 ++++++++++++++++++++-------
 src/support/socket.h         | 65 +++++++++++------------------------------
 3 files changed, 117 insertions(+), 59 deletions(-)

diff --git a/src/support/errno_handling.h b/src/support/errno_handling.h
new file mode 100644
index 0000000000..0bdfdfdf02
--- /dev/null
+++ b/src/support/errno_handling.h
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file errno_handling.h
+ * \brief Common error number handling functions for socket.h and pipe.h
+ */
+#ifndef TVM_SUPPORT_ERRNO_HANDLING_H_
+#define TVM_SUPPORT_ERRNO_HANDLING_H_
+#include <errno.h>
+
+#include "ssize.h"
+
+namespace tvm {
+namespace support {
+/*!
+ * \brief Call a function and retry if an EINTR error is encountered.
+ *
+ *  Socket operations can return EINTR when the interrupt handler
+ *  is registered by the execution environment(e.g. python).
+ *  We should retry if there is no KeyboardInterrupt recorded in
+ *  the environment.
+ *
+ * \note This function is needed to avoid rare interrupt event
+ *       in long running server code.
+ *
+ * \param func The function to retry.
+ * \return The return code returned by function f or error_value on retry failure.
+ */
+template <typename FuncType, typename GetErrorCodeFuncType>
+inline ssize_t RetryCallOnEINTR(FuncType func, GetErrorCodeFuncType fgeterrorcode) {
+  ssize_t ret = func();
+  // common path
+  if (ret != -1) return ret;
+  // less common path
+  do {
+    if (fgeterrorcode() == EINTR) {
+      // Call into env check signals to see if there are
+      // environment specific(e.g. python) signal exceptions.
+      // This function will throw an exception if there is
+      // if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
+      runtime::EnvCheckSignals();
+    } else {
+      // other errors
+      return ret;
+    }
+    ret = func();
+  } while (ret == -1);
+  return ret;
+}
+}  // namespace support
+}  // namespace tvm
+#endif  // TVM_SUPPORT_ERRNO_HANDLING_H_
diff --git a/src/support/pipe.h b/src/support/pipe.h
index d869504dc4..557fe89e46 100644
--- a/src/support/pipe.h
+++ b/src/support/pipe.h
@@ -36,6 +36,7 @@
 #include <cstdlib>
 #include <cstring>
 #endif
+#include "errno_handling.h"
 
 namespace tvm {
 namespace support {
@@ -52,8 +53,21 @@ class Pipe : public dmlc::Stream {
 #endif
   /*! \brief destructor */
   ~Pipe() { Flush(); }
+
   using Stream::Read;
   using Stream::Write;
+
+  /*!
+   * \return last error of pipe operation
+   */
+  static int GetLastErrorCode() {
+#ifdef _WIN32
+    return GetLastError();
+#else
+    return errno;
+#endif
+  }
+
   /*!
    * \brief reads data from a file descriptor
    * \param ptr pointer to a memory buffer
@@ -63,12 +77,15 @@ class Pipe : public dmlc::Stream {
   size_t Read(void* ptr, size_t size) final {
     if (size == 0) return 0;
 #ifdef _WIN32
-    DWORD nread;
-    ICHECK(ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr))
-        << "Read Error: " << GetLastError();
+    auto fread = [&]() {
+      DWORD nread;
+      if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr)) return -1;
+      return nread;
+    };
+    DWORD nread = static_cast<DWORD>(RetryCallOnEINTR(fread, GetLastErrorCode));
+    ICHECK_EQ(static_cast<size_t>(nread), size) << "Read Error: " << GetLastError();
 #else
-    ssize_t nread;
-    nread = read(handle_, ptr, size);
+    ssize_t nread = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode);
     ICHECK_GE(nread, 0) << "Write Error: " << strerror(errno);
 #endif
     return static_cast<size_t>(nread);
@@ -82,13 +99,16 @@ class Pipe : public dmlc::Stream {
   void Write(const void* ptr, size_t size) final {
     if (size == 0) return;
 #ifdef _WIN32
-    DWORD nwrite;
-    ICHECK(WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr) &&
-           static_cast<size_t>(nwrite) == size)
-        << "Write Error: " << GetLastError();
+    auto fwrite = [&]() {
+      DWORD nwrite;
+      if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr)) return -1;
+      return nwrite;
+    };
+    DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
+    ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
 #else
-    ssize_t nwrite;
-    nwrite = write(handle_, ptr, size);
+    ssize_t nwrite =
+        RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
     ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << strerror(errno);
 #endif
   }
diff --git a/src/support/socket.h b/src/support/socket.h
index f62702bbc4..ac13cd3f2d 100644
--- a/src/support/socket.h
+++ b/src/support/socket.h
@@ -39,7 +39,6 @@
 #endif
 #else
 #include <arpa/inet.h>
-#include <errno.h>
 #include <fcntl.h>
 #include <netdb.h>
 #include <netinet/in.h>
@@ -56,8 +55,9 @@
 #include <unordered_map>
 #include <vector>
 
-#include "../support/ssize.h"
-#include "../support/utils.h"
+#include "errno_handling.h"
+#include "ssize.h"
+#include "utils.h"
 
 #if defined(_WIN32)
 static inline int poll(struct pollfd* pfd, int nfds, int timeout) {
@@ -310,7 +310,7 @@ class Socket {
   /*!
    * \return last error of socket operation
    */
-  static int GetLastError() {
+  static int GetLastErrorCode() {
 #ifdef _WIN32
     return WSAGetLastError();
 #else
@@ -319,7 +319,7 @@ class Socket {
   }
   /*! \return whether last error was would block */
   static bool LastErrorWouldBlock() {
-    int errsv = GetLastError();
+    int errsv = GetLastErrorCode();
 #ifdef _WIN32
     return errsv == WSAEWOULDBLOCK;
 #else
@@ -355,7 +355,7 @@ class Socket {
    * \param msg The error message.
    */
   static void Error(const char* msg) {
-    int errsv = GetLastError();
+    int errsv = GetLastErrorCode();
 #ifdef _WIN32
     LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
 #else
@@ -363,42 +363,6 @@ class Socket {
 #endif
   }
 
-  /*!
-   * \brief Call a function and retry if an EINTR error is encountered.
-   *
-   *  Socket operations can return EINTR when the interrupt handler
-   *  is registered by the execution environment(e.g. python).
-   *  We should retry if there is no KeyboardInterrupt recorded in
-   *  the environment.
-   *
-   * \note This function is needed to avoid rare interrupt event
-   *       in long running server code.
-   *
-   * \param func The function to retry.
-   * \return The return code returned by function f or error_value on retry failure.
-   */
-  template <typename FuncType>
-  ssize_t RetryCallOnEINTR(FuncType func) {
-    ssize_t ret = func();
-    // common path
-    if (ret != -1) return ret;
-    // less common path
-    do {
-      if (GetLastError() == EINTR) {
-        // Call into env check signals to see if there are
-        // environment specific(e.g. python) signal exceptions.
-        // This function will throw an exception if there is
-        // if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
-        runtime::EnvCheckSignals();
-      } else {
-        // other errors
-        return ret;
-      }
-      ret = func();
-    } while (ret == -1);
-    return ret;
-  }
-
  protected:
   explicit Socket(SockType sockfd) : sockfd(sockfd) {}
 };
@@ -445,7 +409,8 @@ class TCPSocket : public Socket {
    * \return The accepted socket connection.
    */
   TCPSocket Accept() {
-    SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); });
+    SockType newfd =
+        RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }, GetLastErrorCode);
     if (newfd == INVALID_SOCKET) {
       Socket::Error("Accept");
     }
@@ -459,7 +424,8 @@ class TCPSocket : public Socket {
   TCPSocket Accept(SockAddr* addr) {
     socklen_t addrlen = sizeof(addr->addr);
     SockType newfd = RetryCallOnEINTR(
-        [&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); });
+        [&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); },
+        GetLastErrorCode);
     if (newfd == INVALID_SOCKET) {
       Socket::Error("Accept");
     }
@@ -500,7 +466,7 @@ class TCPSocket : public Socket {
   ssize_t Send(const void* buf_, size_t len, int flag = 0) {
     const char* buf = reinterpret_cast<const char*>(buf_);
     return RetryCallOnEINTR(
-        [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); });
+        [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); }, GetLastErrorCode);
   }
   /*!
    * \brief receive data using the socket
@@ -513,7 +479,8 @@ class TCPSocket : public Socket {
   ssize_t Recv(void* buf_, size_t len, int flags = 0) {
     char* buf = reinterpret_cast<char*>(buf_);
     return RetryCallOnEINTR(
-        [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); });
+        [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); },
+        GetLastErrorCode);
   }
   /*!
    * \brief perform block write that will attempt to send all data out
@@ -527,7 +494,8 @@ class TCPSocket : public Socket {
     size_t ndone = 0;
     while (ndone < len) {
       ssize_t ret = RetryCallOnEINTR(
-          [&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); });
+          [&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); },
+          GetLastErrorCode);
       if (ret == -1) {
         if (LastErrorWouldBlock()) return ndone;
         Socket::Error("SendAll");
@@ -549,7 +517,8 @@ class TCPSocket : public Socket {
     size_t ndone = 0;
     while (ndone < len) {
       ssize_t ret = RetryCallOnEINTR(
-          [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); });
+          [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); },
+          GetLastErrorCode);
       if (ret == -1) {
         if (LastErrorWouldBlock()) {
           LOG(FATAL) << "would block";

Reply via email to