This is an automated email from the ASF dual-hosted git repository.

wwbmmm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git


The following commit(s) were added to refs/heads/master by this push:
     new c666ea45 Support shutdown write and notify for success write (#2547)
c666ea45 is described below

commit c666ea45258c18180dff31a8c7def1a865f71a11
Author: Bright Chen <chenguangmin...@foxmail.com>
AuthorDate: Mon Apr 8 11:07:11 2024 +0800

    Support shutdown write and notify for success write (#2547)
    
    * Support shutdown write and notify for success write
    
    * Fix some problems
---
 src/brpc/controller.cpp       |   1 +
 src/brpc/errno.proto          |   1 +
 src/brpc/socket.cpp           | 130 +++++++++++++++++++------
 src/brpc/socket.h             |  91 ++++++++++++++++--
 test/brpc_socket_unittest.cpp | 215 +++++++++++++++++++++++++++++++++++++++++-
 5 files changed, 402 insertions(+), 36 deletions(-)

diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp
index 42f507fd..1d9b1bb9 100644
--- a/src/brpc/controller.cpp
+++ b/src/brpc/controller.cpp
@@ -80,6 +80,7 @@ BAIDU_REGISTER_ERRNO(brpc::ELOGOFF, "Server is stopping");
 BAIDU_REGISTER_ERRNO(brpc::ELIMIT, "Reached server's max_concurrency");
 BAIDU_REGISTER_ERRNO(brpc::ECLOSE, "Close socket initiatively");
 BAIDU_REGISTER_ERRNO(brpc::EITP, "Bad Itp response");
+BAIDU_REGISTER_ERRNO(brpc::ESHUTDOWNWRITE, "Shutdown write of socket");
 
 #if BRPC_WITH_RDMA
 BAIDU_REGISTER_ERRNO(brpc::ERDMA, "RDMA verbs error");
diff --git a/src/brpc/errno.proto b/src/brpc/errno.proto
index fccd8edb..26ffadc2 100644
--- a/src/brpc/errno.proto
+++ b/src/brpc/errno.proto
@@ -49,6 +49,7 @@ enum Errno {
     ELIMIT                  = 2004;  // Reached server's limit on resources
     ECLOSE                  = 2005;  // Close socket initiatively
     EITP                    = 2006;  // Failed Itp response
+    ESHUTDOWNWRITE          = 2007;  // Shutdown write of socket
 
     // Errno related to RDMA (may happen at both sides)
     ERDMA                   = 3001;  // RDMA verbs error
diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp
index c2e3095d..447bac5f 100644
--- a/src/brpc/socket.cpp
+++ b/src/brpc/socket.cpp
@@ -307,33 +307,56 @@ const uint32_t MAX_PIPELINED_COUNT = 16384;
 
 struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
     static WriteRequest* const UNCONNECTED;
-    
+
     butil::IOBuf data;
     WriteRequest* next;
     bthread_id_t id_wait;
-    Socket* socket;
+
+    void clear_and_set_control_bits(bool notify_on_success,
+                                    bool shutdown_write) {
+        _socket_and_control_bits.set_extra(
+            (uint16_t)notify_on_success << 1 | (uint16_t)shutdown_write);
+    }
+
+    void set_socket(Socket* s) {
+        _socket_and_control_bits.set(s);
+    }
+
+    // If this field is set to true, notify when write successfully.
+    bool is_notify_on_success() const {
+        return _socket_and_control_bits.extra() & ((uint16_t)1 << 1);
+    }
+
+    // Whether shutdown write of the socket after this write complete.
+    bool need_shutdown_write() const {
+        return _socket_and_control_bits.extra() & (uint16_t)1;
+    }
+
+    Socket* get_socket() const {
+        return _socket_and_control_bits.get();
+    }
     
     uint32_t pipelined_count() const {
-        return (_pc_and_udmsg >> 48) & 0x3FFF;
+        return _pc_and_udmsg.extra() & 0x3FFF;
     }
     uint32_t get_auth_flags() const {
-       return (_pc_and_udmsg >> 62) & 0x03;
+       return (_pc_and_udmsg.extra() >> 14) & 0x03;
     }
     void clear_pipelined_count_and_auth_flags() {
-        _pc_and_udmsg &= 0xFFFFFFFFFFFFULL;
+        _pc_and_udmsg.reset_extra();
     }
     SocketMessage* user_message() const {
-        return (SocketMessage*)(_pc_and_udmsg & 0xFFFFFFFFFFFFULL);
+        return _pc_and_udmsg.get();
     }
     void clear_user_message() {
-        _pc_and_udmsg &= 0xFFFF000000000000ULL;
+        _pc_and_udmsg.reset();
     }
     void set_pipelined_count_and_user_message(
         uint32_t pc, SocketMessage* msg, uint32_t auth_flags) {
         if (auth_flags) {
             pc |= (auth_flags & 0x03) << 14;
         }
-        _pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg;
+        _pc_and_udmsg.set_ptr_and_extra(msg, pc);
     }
 
     bool reset_pipelined_count_and_user_message() {
@@ -355,7 +378,10 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
     void Setup(Socket* s);
     
 private:
-    uint64_t _pc_and_udmsg;
+    // Socket pointer and some control bits.
+    PackedPtr<Socket> _socket_and_control_bits;
+    // User message pointer, pipelined count auth flag.
+    PackedPtr<SocketMessage> _pc_and_udmsg;
 };
 
 void Socket::WriteRequest::Setup(Socket* s) {
@@ -399,7 +425,7 @@ public:
     EpollOutRequest() : fd(-1), timer_id(0)
                       , on_epollout_event(NULL), data(NULL) {}
 
-    ~EpollOutRequest() {
+    ~EpollOutRequest() override {
         // Remove the timer at last inside destructor to avoid
         // race with the place that registers the timer
         if (timer_id) {
@@ -408,8 +434,8 @@ public:
         }
     }
     
-    void BeforeRecycle(Socket*) {
-        // Recycle itself
+    void BeforeRecycle(Socket*) override {
+        // Recycle itself.
         delete this;
     }
 
@@ -464,6 +490,7 @@ Socket::Socket(Forbidden)
     , _unwritten_bytes(0)
     , _epollout_butex(NULL)
     , _write_head(NULL)
+    , _is_write_shutdown(false)
     , _stream_set(NULL)
     , _total_streams_unconsumed_size(0)
     , _ninflight_app_health_check(0)
@@ -485,7 +512,11 @@ void Socket::ReturnSuccessfulWriteRequest(Socket::WriteRequest* p) {
     const bthread_id_t id_wait = p->id_wait;
     butil::return_object(p);
     if (id_wait != INVALID_BTHREAD_ID) {
-        NotifyOnFailed(id_wait);
+        if (p->is_notify_on_success() && !Failed()) {
+            bthread_id_error(id_wait, 0);
+        } else {
+            NotifyOnFailed(id_wait);
+        }
     }
 }
 
@@ -514,11 +545,18 @@ Socket::WriteRequest* Socket::ReleaseWriteRequestsExceptLast(
 }
 
 void Socket::ReleaseAllFailedWriteRequests(Socket::WriteRequest* req) {
-    CHECK(Failed());
-    pthread_mutex_lock(&_id_wait_list_mutex);
-    const int error_code = non_zero_error_code();
-    const std::string error_text = _error_text;
-    pthread_mutex_unlock(&_id_wait_list_mutex);
+    CHECK(Failed() || IsWriteShutdown());
+    int error_code;
+    std::string error_text;
+    if (Failed()) {
+        pthread_mutex_lock(&_id_wait_list_mutex);
+        error_code = non_zero_error_code();
+        error_text = _error_text;
+        pthread_mutex_unlock(&_id_wait_list_mutex);
+    } else {
+        error_code = ESHUTDOWNWRITE;
+        error_text = "Shutdown write of the socket";
+    }
     // Notice that `req' is not tail if Address after IsWriteComplete fails.
     do {
         req = ReleaseWriteRequestsExceptLast(req, error_code, error_text);
@@ -746,6 +784,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
     m->_keepalive_options = options.keepalive_options;
     m->_bthread_tag = options.bthread_tag;
     CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed));
+    m->_is_write_shutdown = false;
     // Must be last one! Internal fields of this Socket may be access
     // just after calling ResetFileDescriptor.
     if (m->ResetFileDescriptor(options.fd) != 0) {
@@ -1382,7 +1421,7 @@ int Socket::ConnectIfNot(const timespec* abstime, WriteRequest* req) {
     // Have to hold a reference for `req'
     SocketUniquePtr s;
     ReAddress(&s);
-    req->socket = s.get();
+    req->set_socket(s.get());
     if (_conn) {
         if (_conn->Connect(this, abstime, KeepWriteIfConnected, req) < 0) {
             return -1;
@@ -1454,7 +1493,7 @@ int Socket::HandleEpollOutRequest(int error_code, EpollOutRequest* req) {
 void Socket::AfterAppConnected(int err, void* data) {
     WriteRequest* req = static_cast<WriteRequest*>(data);
     if (err == 0) {
-        Socket* const s = req->socket;
+        Socket* const s = req->get_socket();
         SharedPart* sp = s->GetSharedPart();
         if (sp) {
             sp->num_continuous_connect_timeouts.store(0, butil::memory_order_relaxed);
@@ -1468,7 +1507,7 @@ void Socket::AfterAppConnected(int err, void* data) {
             KeepWrite(req);
         }
     } else {
-        SocketUniquePtr s(req->socket);
+        SocketUniquePtr s(req->get_socket());
         if (err == ETIMEDOUT) {
             SharedPart* sp = s->GetOrNewSharedPart();
             if (sp->num_continuous_connect_timeouts.fetch_add(
@@ -1496,7 +1535,7 @@ static void* RunClosure(void* arg) {
 
 int Socket::KeepWriteIfConnected(int fd, int err, void* data) {
     WriteRequest* req = static_cast<WriteRequest*>(data);
-    Socket* s = req->socket;
+    Socket* s = req->get_socket();
     if (err == 0 && s->ssl_state() == SSL_CONNECTING) {
         // Run ssl connect in a new bthread to avoid blocking
         // the current bthread (thus blocking the EventDispatcher)
@@ -1519,7 +1558,7 @@ int Socket::KeepWriteIfConnected(int fd, int err, void* data) {
 void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) {
     butil::fd_guard sockfd(fd);
     WriteRequest* req = static_cast<WriteRequest*>(data);
-    Socket* s = req->socket;
+    Socket* s = req->get_socket();
     CHECK_GE(sockfd, 0);
     if (err == 0 && s->CheckConnected(sockfd) == 0
         && s->ResetFileDescriptor(sockfd) == 0) {
@@ -1527,7 +1566,8 @@ void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) {
             g_vars->channel_conn << 1;
         }
         if (s->_app_connect) {
-            s->_app_connect->StartConnect(req->socket, AfterAppConnected, req);
+            s->_app_connect->StartConnect(req->get_socket(),
+                                          AfterAppConnected, req);
         } else {
             // Successfully created a connection
             AfterAppConnected(0, req);
@@ -1614,6 +1654,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) {
     // wait until it points to a valid WriteRequest or NULL.
     req->next = WriteRequest::UNCONNECTED;
     req->id_wait = opt.id_wait;
+    req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write);
     req->set_pipelined_count_and_user_message(
         opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags);
     return StartWrite(req, opt);
@@ -1650,7 +1691,9 @@ int Socket::Write(SocketMessagePtr<>& msg, const WriteOptions* options_in) {
     // wait until it points to a valid WriteRequest or NULL.
     req->next = WriteRequest::UNCONNECTED;
     req->id_wait = opt.id_wait;
-    req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.auth_flags);
+    req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write);
+    req->set_pipelined_count_and_user_message(
+        opt.pipelined_count, msg.release(), opt.auth_flags);
     return StartWrite(req, opt);
 }
 
@@ -1672,12 +1715,19 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) {
     bthread_t th;
     SocketUniquePtr ptr_for_keep_write;
     ssize_t nw = 0;
+    int ret = 0;
 
     // We've got the right to write.
     req->next = NULL;
+
+    // Fast fail when write has been shutdown.
+    if (_is_write_shutdown) {
+        goto FAIL_TO_WRITE;
+    }
+    _is_write_shutdown = req->need_shutdown_write();
     
     // Connect to remote_side() if not.
-    int ret = ConnectIfNot(opt.abstime, req);
+    ret = ConnectIfNot(opt.abstime, req);
     if (ret < 0) {
         saved_errno = errno;
         SetFailed(errno, "Fail to connect %s directly: %m", description().c_str());
@@ -1736,7 +1786,7 @@ int Socket::StartWrite(WriteRequest* req, const WriteOptions& opt) {
 
 KEEPWRITE_IN_BACKGROUND:
     ReAddress(&ptr_for_keep_write);
-    req->socket = ptr_for_keep_write.release();
+    req->set_socket(ptr_for_keep_write.release());
     if (bthread_start_background(&th, &BTHREAD_ATTR_NORMAL,
                                  KeepWrite, req) != 0) {
         LOG(FATAL) << "Fail to start KeepWrite";
@@ -1758,7 +1808,7 @@ static const size_t DATA_LIST_MAX = 256;
 void* Socket::KeepWrite(void* void_arg) {
     g_vars->nkeepwrite << 1;
     WriteRequest* req = static_cast<WriteRequest*>(void_arg);
-    SocketUniquePtr s(req->socket);
+    SocketUniquePtr s(req->get_socket());
 
     // When error occurs, spin until there's no more requests instead of
     // returning directly otherwise _write_head is permantly non-NULL which
@@ -1766,11 +1816,18 @@ void* Socket::KeepWrite(void* void_arg) {
     WriteRequest* cur_tail = NULL;
     do {
         // req was written, skip it.
+        bool need_shutdown = false;
         if (req->next != NULL && req->data.empty()) {
             WriteRequest* const saved_req = req;
+            need_shutdown = req->need_shutdown_write();
             req = req->next;
             s->ReturnSuccessfulWriteRequest(saved_req);
         }
+        if (need_shutdown) {
+            LOG(WARNING) << "Shutdown write of " << *s;
+            break;
+        }
+
         const ssize_t nw = s->DoWrite(req);
         if (nw < 0) {
             if (errno != EAGAIN && errno != EOVERCROWDED) {
@@ -1783,11 +1840,19 @@ void* Socket::KeepWrite(void* void_arg) {
         } else {
             s->AddOutputBytes(nw);
         }
-        // Release WriteRequest until non-empty data or last request.
+        // Release WriteRequest until non-empty data, last request or shutdown write.
         while (req->next != NULL && req->data.empty()) {
             WriteRequest* const saved_req = req;
+            need_shutdown = req->need_shutdown_write();
             req = req->next;
             s->ReturnSuccessfulWriteRequest(saved_req);
+            if (need_shutdown) {
+                break;
+            }
+        }
+        if (need_shutdown) {
+            LOG(WARNING) << "Shutdown write of " << *s;
+            break;
         }
         // TODO(gejun): wait for epollout when we actually have written
         // all the data. This weird heuristic reduces 30us delay...
@@ -1867,6 +1932,11 @@ ssize_t Socket::DoWrite(WriteRequest* req) {
     for (WriteRequest* p = req; p != NULL && ndata < DATA_LIST_MAX;
          p = p->next) {
         data_list[ndata++] = &p->data;
+        if (p->need_shutdown_write()) {
+            // Write WriteRequest until shutdown write.
+            _is_write_shutdown = true;
+            break;
+        }
     }
 
     if (ssl_state() == SSL_OFF) {
@@ -2387,6 +2457,8 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) {
         os << "\n}";
     }
 
+    os << "\nis_wirte_shutdown=" << ptr->_is_write_shutdown;
+
     {
         int keepalive = 0;
         socklen_t len = sizeof(keepalive);
diff --git a/src/brpc/socket.h b/src/brpc/socket.h
index faf6baac..97ce5685 100644
--- a/src/brpc/socket.h
+++ b/src/brpc/socket.h
@@ -166,6 +166,62 @@ struct PipelinedInfo {
     bthread_id_t id_wait;
 };
 
+// A data structure packed with a pointer and
+// some extra information using a uint64 variable.
+template <class T>
+class PackedPtr {
+    static constexpr uint8_t MAX_POINTER_LEN = 48;
+    static constexpr uint64_t POINTER_MASK = ((uint64_t)1 << MAX_POINTER_LEN) - 1;
+    static constexpr uint64_t EXTRA_MASK = ~POINTER_MASK;
+public:
+    PackedPtr() : _data(0) {
+        BAIDU_CASSERT(sizeof(PackedPtr) == 8, sizeof_packed_ptr_must_be_8);
+    }
+
+    void set(T* ptr) {
+        // Clear the low 48 bits and then
+        // store the pointer in the low 48 bits.
+        _data = (_data & EXTRA_MASK) |
+                ((uint64_t)(uintptr_t)ptr & POINTER_MASK);
+    }
+
+    void reset() {
+        // Clear the low 48 bits.
+        _data &= EXTRA_MASK;
+    }
+
+    T* get() const { return (T*)(_data & POINTER_MASK); }
+
+    void set_extra(uint16_t extra) {
+        // Clear the high 16 bits and then
+        // store the extra in the high 16 bits.
+        _data = (_data & POINTER_MASK) |
+                ((uint64_t)extra << MAX_POINTER_LEN);
+    }
+
+    void reset_extra() {
+        // Clear the high 16 bits.
+        _data &= POINTER_MASK;
+    }
+
+    uint16_t extra() const { return _data >> MAX_POINTER_LEN; }
+
+    void set_ptr_and_extra(T* p, uint16_t extra) {
+        _data = ((uint64_t)(uintptr_t)p & POINTER_MASK) |
+                ((uint64_t)extra << MAX_POINTER_LEN);
+    }
+
+    void reset_ptr_and_extra() {
+        _data = 0;
+    }
+
+private:
+    // Pointer is stored in the low 48 bits,
+    // extra information is stored in the high 16 bits.
+    uint64_t _data;
+};
+
+
 struct SocketSSLContext {
     SocketSSLContext();
     ~SocketSSLContext();
@@ -269,11 +325,18 @@ public:
     // - Write once when uncontended(most cases).
     // - Wait-free when contended.
     struct WriteOptions {
-        // `id_wait' is signalled when this Socket is SetFailed. To disable
-        // the signal, set this field to INVALID_BTHREAD_ID.
-        // `on_reset' of `id_wait' is NOT called when Write() returns non-zero.
+        // `id_wait' is signalled when this Socket is SetFailed or data is written
+        // successfully with `notify_on_success=true'. To disable the signal, set
+        // this field to INVALID_BTHREAD_ID. `on_reset' of `id_wait' is NOT called
+        // when Write() returns non-zero.
         // Default: INVALID_BTHREAD_ID
         bthread_id_t id_wait;
+
+        // If this field is set to true and `id_wait' is not INVALID_BTHREAD_ID,
+        // `id_wait' can be signalled when write successfully.
+        // Default: false
+        bool notify_on_success;
+
         // If no connection exists, a connection will be established to
         // remote_side() regarding deadline `abstime'. NULL means no timeout.
         // Default: NULL
@@ -301,13 +364,27 @@ public:
         // performance. Otherwise, each write only writes one `msg` into socket
         // and no KeepWrite thread can be created, which brings poor
         // performance.
+        // Default: false
         bool write_in_background;
 
+        // After this write complete, shutdown write of the socket.
+        // Default: false
+        bool shutdown_write;
+
         WriteOptions()
-            : id_wait(INVALID_BTHREAD_ID), abstime(NULL)
-            , pipelined_count(0), auth_flags(0)
-            , ignore_eovercrowded(false), write_in_background(false) {}
+            : id_wait(INVALID_BTHREAD_ID)
+            , notify_on_success(false)
+            , abstime(NULL)
+            , pipelined_count(0)
+            , auth_flags(0)
+            , ignore_eovercrowded(false)
+            , write_in_background(false)
+            , shutdown_write(false) {}
     };
+
+    // True if write of socket is shutdown.
+    bool IsWriteShutdown() const { return _is_write_shutdown; }
+
     int Write(butil::IOBuf *msg, const WriteOptions* options = NULL);
 
     // Write an user-defined message. `msg' is released when Write() is
@@ -917,6 +994,8 @@ private:
     // Storing data that are not flushed into `fd' yet.
     butil::atomic<WriteRequest*> _write_head;
 
+    bool _is_write_shutdown;
+
     butil::Mutex _stream_mutex;
     std::set<StreamId> *_stream_set;
     butil::atomic<int64_t> _total_streams_unconsumed_size;
diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp
index d2258735..f278c46b 100644
--- a/test/brpc_socket_unittest.cpp
+++ b/test/brpc_socket_unittest.cpp
@@ -1225,7 +1225,6 @@ TEST_F(SocketTest, keepalive) {
     }
 }
 
-
 TEST_F(SocketTest, keepalive_input_message) {
     int default_keepalive = 0;
     int default_keepalive_idle = 0;
@@ -1418,3 +1417,217 @@ TEST_F(SocketTest, keepalive_input_message) {
         sockfd.release();
     }
 }
+
+int HandleSocketSuccessWrite(bthread_id_t id, void* data, int error_code,
+    const std::string& error_text) {
+    auto success_count = static_cast<size_t*>(data);
+    EXPECT_NE(nullptr, success_count);
+    EXPECT_EQ(0, error_code);
+    ++(*success_count);
+    CHECK_EQ(0, bthread_id_unlock_and_destroy(id));
+    return 0;
+}
+
+TEST_F(SocketTest, notify_on_success) {
+    const size_t REP = 10000;
+    int fds[2];
+    ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
+
+    brpc::SocketId id = 8888;
+    butil::EndPoint dummy;
+    ASSERT_EQ(0, str2endpoint("192.168.1.26:8080", &dummy));
+    brpc::SocketOptions options;
+    options.fd = fds[1];
+    options.remote_side = dummy;
+    options.user = new CheckRecycle;
+    ASSERT_EQ(0, brpc::Socket::Create(options, &id));
+    brpc::SocketUniquePtr s;
+    ASSERT_EQ(0, brpc::Socket::Address(id, &s));
+    s->_ssl_state = brpc::SSL_OFF;
+    ASSERT_EQ(2, brpc::NRefOfVRef(s->_versioned_ref));
+    global_sock = s.get();
+    ASSERT_TRUE(s.get());
+    ASSERT_EQ(fds[1], s->fd());
+    ASSERT_EQ(dummy, s->remote_side());
+    ASSERT_EQ(id, s->id());
+
+    pthread_t rth;
+    ReaderArg reader_arg = { fds[0], 0 };
+    pthread_create(&rth, NULL, reader, &reader_arg);
+
+    size_t success_count = 0;
+    char buf[] = "hello reader side!";
+    for (size_t c = 0; c < REP; ++c) {
+        bthread_id_t write_id;
+        ASSERT_EQ(0, bthread_id_create2(&write_id, &success_count,
+            HandleSocketSuccessWrite));
+        brpc::Socket::WriteOptions wopt;
+        wopt.id_wait = write_id;
+        wopt.notify_on_success = true;
+        butil::IOBuf src;
+        src.append(buf, 16);
+        if (s->Write(&src, &wopt) != 0) {
+            if (errno == brpc::EOVERCROWDED) {
+                // The buf is full, sleep a while and retry.
+                bthread_usleep(1000);
+                --c;
+                continue;
+            }
+            PLOG(ERROR) << "Fail to write into SocketId=" << id;
+            break;
+        }
+    }
+    bthread_usleep(1000 * 1000);
+
+    ASSERT_EQ(0, s->SetFailed());
+    s.release()->Dereference();
+    pthread_join(rth, NULL);
+    ASSERT_EQ(REP, success_count);
+    ASSERT_EQ((brpc::Socket*)NULL, global_sock);
+    close(fds[0]);
+}
+
+struct ShutdownWriterArg {
+    size_t times;
+    brpc::SocketId socket_id;
+    butil::atomic<int> total_count;
+    butil::atomic<int> success_count;
+};
+
+int HandleSocketShutdownWrite(bthread_id_t id, void* data, int error_code,
+    const std::string& error_text) {
+    auto arg = static_cast<ShutdownWriterArg*>(data);
+    EXPECT_NE(nullptr, arg);
+    EXPECT_TRUE(0 == error_code || brpc::ESHUTDOWNWRITE == error_code) << error_code;
+    ++arg->total_count;
+    if (0 == error_code) {
+        ++arg->success_count;
+    }
+    CHECK_EQ(0, bthread_id_unlock_and_destroy(id));
+    return 0;
+}
+
+void* ShutdownWriter(void* void_arg) {
+    auto arg = static_cast<ShutdownWriterArg*>(void_arg);
+    brpc::SocketUniquePtr sock;
+    if (brpc::Socket::Address(arg->socket_id, &sock) < 0) {
+        LOG(INFO) << "Fail to address SocketId=" << arg->socket_id;
+        return NULL;
+    }
+    for (size_t c = 0; c < arg->times; ++c) {
+        bthread_id_t write_id;
+        EXPECT_EQ(0, bthread_id_create2(&write_id, arg,
+            HandleSocketShutdownWrite));
+        brpc::Socket::WriteOptions wopt;
+        wopt.id_wait = write_id;
+        wopt.notify_on_success = true;
+        wopt.shutdown_write = true;
+        butil::IOBuf src;
+        src.push_back('a');
+        if (sock->Write(&src, &wopt) != 0) {
+            if (errno == brpc::EOVERCROWDED) {
+                // The buf is full, sleep a while and retry.
+                bthread_usleep(1000);
+                --c;
+                continue;
+            }
+        }
+    }
+    return NULL;
+}
+
+void TestShutdownWrite() {
+    const size_t REP = 100;
+    int fds[2];
+    ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
+
+    brpc::SocketId id = 8888;
+    butil::EndPoint dummy;
+    ASSERT_EQ(0, str2endpoint("192.168.1.26:8080", &dummy));
+    brpc::SocketOptions options;
+    options.fd = fds[1];
+    options.remote_side = dummy;
+    options.user = new CheckRecycle;
+    ASSERT_EQ(0, brpc::Socket::Create(options, &id));
+    brpc::SocketUniquePtr s;
+    ASSERT_EQ(0, brpc::Socket::Address(id, &s));
+    s->_ssl_state = brpc::SSL_OFF;
+    ASSERT_EQ(2, brpc::NRefOfVRef(s->_versioned_ref));
+    global_sock = s.get();
+    ASSERT_TRUE(s.get());
+    ASSERT_EQ(fds[1], s->fd());
+    ASSERT_EQ(dummy, s->remote_side());
+    ASSERT_EQ(id, s->id());
+    ASSERT_FALSE(s->IsWriteShutdown());
+
+    pthread_t rth;
+    ReaderArg reader_arg = { fds[0], 0 };
+    pthread_create(&rth, NULL, reader, &reader_arg);
+
+    bthread_t th[3];
+    ShutdownWriterArg args[ARRAY_SIZE(th)];
+    for (size_t i = 0; i < ARRAY_SIZE(th); ++i) {
+        args[i].times = REP;
+        args[i].socket_id = id;
+        args[i].total_count = 0;
+        args[i].success_count = 0;
+        bthread_start_background(&th[i], NULL, ShutdownWriter, &args[i]);
+    }
+
+    for (size_t i = 0; i < ARRAY_SIZE(th); ++i) {
+        ASSERT_EQ(0, bthread_join(th[i], NULL));
+    }
+    bthread_usleep(50 * 1000);
+
+    ASSERT_TRUE(s->IsWriteShutdown());
+    ASSERT_FALSE(s->Failed());
+    ASSERT_EQ(0, s->SetFailed());
+    s.release()->Dereference();
+    pthread_join(rth, NULL);
+    ASSERT_EQ((brpc::Socket*)NULL, global_sock);
+    close(fds[0]);
+
+    size_t total_count = 0;
+    size_t success_count = 0;
+    for (auto & arg : args) {
+        total_count += arg.total_count;
+        success_count += arg.success_count;
+    }
+    ASSERT_EQ(REP * ARRAY_SIZE(th), total_count);
+    EXPECT_EQ((size_t)1, reader_arg.nread);
+    EXPECT_EQ((size_t)1, success_count);
+}
+
+TEST_F(SocketTest, shutdown_write) {
+    for (int i = 0; i < 100; ++i) {
+        TestShutdownWrite();
+    }
+}
+
+TEST_F(SocketTest, packed_ptr) {
+    brpc::PackedPtr<int> ptr;
+    ASSERT_EQ(nullptr, ptr.get());
+    ASSERT_EQ(0, ptr.extra());
+
+    int a = 1;
+    uint16_t b = 2;
+    ptr.set(&a);
+    ASSERT_EQ(&a, ptr.get());
+    *ptr.get() = b;
+    ASSERT_EQ(a, b);
+    ptr.set_extra(b);
+    ASSERT_EQ(b, ptr.extra());
+    ptr.reset();
+    ptr.reset_extra();
+    ASSERT_EQ(nullptr, ptr.get());
+    ASSERT_EQ(0, ptr.extra());
+
+    int c = 3;
+    uint16_t d = 4;
+    ptr.set_ptr_and_extra(&c, d);
+    ASSERT_EQ(&c, ptr.get());
+    ASSERT_EQ(d, ptr.extra());
+    ptr.reset_ptr_and_extra();
+    ASSERT_EQ(nullptr, ptr.get());
+    ASSERT_EQ(0, ptr.extra());
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@brpc.apache.org
For additional commands, e-mail: dev-h...@brpc.apache.org

Reply via email to