This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 759ee1236a [Support] Add Interrupt Handling in Pipe (#16255)
759ee1236a is described below
commit 759ee1236ad1b95fc6b9c356693a8f1062ceee13
Author: Lesheng Jin <[email protected]>
AuthorDate: Wed Dec 20 00:27:04 2023 +0800
[Support] Add Interrupt Handling in Pipe (#16255)
Co-authored-by: Sunghyun Park <[email protected]>
---
src/support/errno_handling.h | 69 ++++++++++++++++++++++++++++++++++++++++++++
src/support/pipe.h | 42 ++++++++++++++++++++-------
src/support/socket.h | 65 +++++++++++------------------------------
3 files changed, 117 insertions(+), 59 deletions(-)
diff --git a/src/support/errno_handling.h b/src/support/errno_handling.h
new file mode 100644
index 0000000000..0bdfdfdf02
--- /dev/null
+++ b/src/support/errno_handling.h
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file errno_handling.h
+ * \brief Common error number handling functions for socket.h and pipe.h
+ */
+#ifndef TVM_SUPPORT_ERRNO_HANDLING_H_
+#define TVM_SUPPORT_ERRNO_HANDLING_H_
+#include <errno.h>
+#include <tvm/runtime/registry.h>
+
+#include "ssize.h"
+
+namespace tvm {
+namespace support {
+/*!
+ * \brief Call a function and retry if an EINTR error is encountered.
+ *
+ * Socket and pipe operations can return EINTR when an interrupt handler
+ * is registered by the execution environment (e.g. python). Before each
+ * retry, runtime::EnvCheckSignals() gives the environment a chance to
+ * raise a pending signal (e.g. KeyboardInterrupt) as an exception.
+ *
+ * \param func The function to call; a return value of -1 signals failure.
+ * \param fgeterrorcode Returns the error code of the last failed call. The
+ * code is compared against EINTR without any translation.
+ * \return The first value returned by func that is not -1, or -1 if func
+ * failed with an error code other than EINTR.
+ */
+template <typename FuncType, typename GetErrorCodeFuncType>
+inline ssize_t RetryCallOnEINTR(FuncType func, GetErrorCodeFuncType fgeterrorcode) {
+ ssize_t ret = func();
+ // common path
+ if (ret != -1) return ret;
+ // less common path
+ do {
+ if (fgeterrorcode() == EINTR) {
+ // Let the environment (e.g. python) handle pending signals. This throws
+ // if the process received a signal that requires TVM to return
+ // immediately (e.g. SIGINT); otherwise the call is retried.
+ runtime::EnvCheckSignals();
+ } else {
+ // other errors
+ return ret;
+ }
+ ret = func();
+ } while (ret == -1);
+ return ret;
+}
+} // namespace support
+} // namespace tvm
+#endif // TVM_SUPPORT_ERRNO_HANDLING_H_
diff --git a/src/support/pipe.h b/src/support/pipe.h
index d869504dc4..557fe89e46 100644
--- a/src/support/pipe.h
+++ b/src/support/pipe.h
@@ -36,6 +36,7 @@
#include <cstdlib>
#include <cstring>
#endif
+#include "errno_handling.h"

namespace tvm {
namespace support {
@@ -52,8 +53,21 @@ class Pipe : public dmlc::Stream {
#endif
/*! \brief destructor */
~Pipe() { Flush(); }
+
using Stream::Read;
using Stream::Write;
+
+ /*!
+ * \return last error of pipe operation
+ */
+ static int GetLastErrorCode() {
+#ifdef _WIN32
+ return GetLastError();
+#else
+ return errno;
+#endif
+ }
+
/*!
* \brief reads data from a file descriptor
* \param ptr pointer to a memory buffer
@@ -63,12 +77,15 @@ class Pipe : public dmlc::Stream {
size_t Read(void* ptr, size_t size) final {
if (size == 0) return 0;
#ifdef _WIN32
- DWORD nread;
- ICHECK(ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr))
- << "Read Error: " << GetLastError();
+ auto fread = [&]() -> ssize_t {
+ DWORD nread;
+ if (!ReadFile(handle_, static_cast<TCHAR*>(ptr), size, &nread, nullptr)) return -1;
+ return static_cast<ssize_t>(nread);
+ };
+ ssize_t nread = RetryCallOnEINTR(fread, GetLastErrorCode);
+ ICHECK_GE(nread, 0) << "Read Error: " << GetLastError();
#else
- ssize_t nread;
- nread = read(handle_, ptr, size);
+ ssize_t nread = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_GE(nread, 0) << "Write Error: " << strerror(errno);
#endif
return static_cast<size_t>(nread);
@@ -82,13 +99,16 @@ class Pipe : public dmlc::Stream {
void Write(const void* ptr, size_t size) final {
if (size == 0) return;
#ifdef _WIN32
- DWORD nwrite;
- ICHECK(WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr) &&
- static_cast<size_t>(nwrite) == size)
- << "Write Error: " << GetLastError();
+ auto fwrite = [&]() -> ssize_t {
+ DWORD nwrite;
+ if (!WriteFile(handle_, static_cast<const TCHAR*>(ptr), size, &nwrite, nullptr)) return -1;
+ return static_cast<ssize_t>(nwrite);
+ };
+ ssize_t nwrite = RetryCallOnEINTR(fwrite, GetLastErrorCode);
+ ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
#else
- ssize_t nwrite;
- nwrite = write(handle_, ptr, size);
+ ssize_t nwrite =
+ RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << strerror(errno);
#endif
}
diff --git a/src/support/socket.h b/src/support/socket.h
index f62702bbc4..ac13cd3f2d 100644
--- a/src/support/socket.h
+++ b/src/support/socket.h
@@ -39,7 +39,6 @@
#endif
#else
#include <arpa/inet.h>
-#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
@@ -56,8 +55,9 @@
#include <unordered_map>
#include <vector>

-#include "../support/ssize.h"
-#include "../support/utils.h"
+#include "errno_handling.h"
+#include "ssize.h"
+#include "utils.h"

#if defined(_WIN32)
static inline int poll(struct pollfd* pfd, int nfds, int timeout) {
@@ -310,7 +310,7 @@ class Socket {
/*!
* \return last error of socket operation
*/
- static int GetLastError() {
+ static int GetLastErrorCode() {
#ifdef _WIN32
return WSAGetLastError();
#else
@@ -319,7 +319,7 @@ class Socket {
}
/*! \return whether last error was would block */
static bool LastErrorWouldBlock() {
- int errsv = GetLastError();
+ int errsv = GetLastErrorCode();
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
@@ -355,7 +355,7 @@ class Socket {
* \param msg The error message.
*/
static void Error(const char* msg) {
- int errsv = GetLastError();
+ int errsv = GetLastErrorCode();
#ifdef _WIN32
LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
#else
@@ -363,42 +363,6 @@ class Socket {
#endif
}

- /*!
- * \brief Call a function and retry if an EINTR error is encountered.
- *
- * Socket operations can return EINTR when the interrupt handler
- * is registered by the execution environment(e.g. python).
- * We should retry if there is no KeyboardInterrupt recorded in
- * the environment.
- *
- * \note This function is needed to avoid rare interrupt event
- * in long running server code.
- *
- * \param func The function to retry.
- * \return The return code returned by function f or error_value on retry failure.
- */
- template <typename FuncType>
- ssize_t RetryCallOnEINTR(FuncType func) {
- ssize_t ret = func();
- // common path
- if (ret != -1) return ret;
- // less common path
- do {
- if (GetLastError() == EINTR) {
- // Call into env check signals to see if there are
- // environment specific(e.g. python) signal exceptions.
- // This function will throw an exception if there is
- // if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
- runtime::EnvCheckSignals();
- } else {
- // other errors
- return ret;
- }
- ret = func();
- } while (ret == -1);
- return ret;
- }
-
protected:
explicit Socket(SockType sockfd) : sockfd(sockfd) {}
};
@@ -445,7 +409,8 @@ class TCPSocket : public Socket {
* \return The accepted socket connection.
*/
TCPSocket Accept() {
- SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); });
+ SockType newfd =
+ RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }, GetLastErrorCode);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
@@ -459,7 +424,8 @@ class TCPSocket : public Socket {
TCPSocket Accept(SockAddr* addr) {
socklen_t addrlen = sizeof(addr->addr);
SockType newfd = RetryCallOnEINTR(
- [&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); });
+ [&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); },
+ GetLastErrorCode);
if (newfd == INVALID_SOCKET) {
Socket::Error("Accept");
}
@@ -500,7 +466,7 @@ class TCPSocket : public Socket {
ssize_t Send(const void* buf_, size_t len, int flag = 0) {
const char* buf = reinterpret_cast<const char*>(buf_);
return RetryCallOnEINTR(
- [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); });
+ [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); }, GetLastErrorCode);
}
/*!
* \brief receive data using the socket
@@ -513,7 +479,8 @@ class TCPSocket : public Socket {
ssize_t Recv(void* buf_, size_t len, int flags = 0) {
char* buf = reinterpret_cast<char*>(buf_);
return RetryCallOnEINTR(
- [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); });
+ [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); },
+ GetLastErrorCode);
}
/*!
* \brief perform block write that will attempt to send all data out
@@ -527,7 +494,8 @@ class TCPSocket : public Socket {
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = RetryCallOnEINTR(
- [&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); });
+ [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len - ndone), 0); },
+ GetLastErrorCode);
if (ret == -1) {
if (LastErrorWouldBlock()) return ndone;
Socket::Error("SendAll");
@@ -549,7 +517,8 @@ class TCPSocket : public Socket {
size_t ndone = 0;
while (ndone < len) {
ssize_t ret = RetryCallOnEINTR(
- [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); });
+ [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); },
+ GetLastErrorCode);
if (ret == -1) {
if (LastErrorWouldBlock()) {
LOG(FATAL) << "would block";