This is an automated email from the ASF dual-hosted git repository.
wwbmmm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git
The following commit(s) were added to refs/heads/master by this push:
new c666ea45 Support shutdown write and notify for success write (#2547)
c666ea45 is described below
commit c666ea45258c18180dff31a8c7def1a865f71a11
Author: Bright Chen <[email protected]>
AuthorDate: Mon Apr 8 11:07:11 2024 +0800
Support shutdown write and notify for success write (#2547)
* Support shutdown write and notify for success write
* Fix some problems
---
src/brpc/controller.cpp | 1 +
src/brpc/errno.proto | 1 +
src/brpc/socket.cpp | 130 +++++++++++++++++++------
src/brpc/socket.h | 91 ++++++++++++++++--
test/brpc_socket_unittest.cpp | 215 +++++++++++++++++++++++++++++++++++++++++-
5 files changed, 402 insertions(+), 36 deletions(-)
diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp
index 42f507fd..1d9b1bb9 100644
--- a/src/brpc/controller.cpp
+++ b/src/brpc/controller.cpp
@@ -80,6 +80,7 @@ BAIDU_REGISTER_ERRNO(brpc::ELOGOFF, "Server is stopping");
BAIDU_REGISTER_ERRNO(brpc::ELIMIT, "Reached server's max_concurrency");
BAIDU_REGISTER_ERRNO(brpc::ECLOSE, "Close socket initiatively");
BAIDU_REGISTER_ERRNO(brpc::EITP, "Bad Itp response");
+BAIDU_REGISTER_ERRNO(brpc::ESHUTDOWNWRITE, "Shutdown write of socket");
#if BRPC_WITH_RDMA
BAIDU_REGISTER_ERRNO(brpc::ERDMA, "RDMA verbs error");
diff --git a/src/brpc/errno.proto b/src/brpc/errno.proto
index fccd8edb..26ffadc2 100644
--- a/src/brpc/errno.proto
+++ b/src/brpc/errno.proto
@@ -49,6 +49,7 @@ enum Errno {
ELIMIT = 2004; // Reached server's limit on resources
ECLOSE = 2005; // Close socket initiatively
EITP = 2006; // Failed Itp response
+ ESHUTDOWNWRITE = 2007; // Shutdown write of socket
// Errno related to RDMA (may happen at both sides)
ERDMA = 3001; // RDMA verbs error
diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp
index c2e3095d..447bac5f 100644
--- a/src/brpc/socket.cpp
+++ b/src/brpc/socket.cpp
@@ -307,33 +307,56 @@ const uint32_t MAX_PIPELINED_COUNT = 16384;
struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
static WriteRequest* const UNCONNECTED;
-
+
butil::IOBuf data;
WriteRequest* next;
bthread_id_t id_wait;
- Socket* socket;
+
+ // Overwrites all control bits: bit 1 holds `notify_on_success' and bit 0
+ // holds `shutdown_write'. The stored socket pointer is left untouched.
+ void clear_and_set_control_bits(bool notify_on_success,
+ bool shutdown_write) {
+ _socket_and_control_bits.set_extra(
+ (uint16_t)notify_on_success << 1 | (uint16_t)shutdown_write);
+ }
+
+ // Stores the socket pointer without touching the control bits.
+ void set_socket(Socket* s) {
+ _socket_and_control_bits.set(s);
+ }
+
+ // True if `id_wait' should be signalled with error code 0 once this
+ // request has been written successfully (control bit 1); see
+ // ReturnSuccessfulWriteRequest.
+ bool is_notify_on_success() const {
+ return _socket_and_control_bits.extra() & ((uint16_t)1 << 1);
+ }
+
+ // True if writing on the socket should be shut down once this request
+ // has been written (control bit 0).
+ bool need_shutdown_write() const {
+ return _socket_and_control_bits.extra() & (uint16_t)1;
+ }
+
+ // Returns the socket pointer stored by set_socket().
+ Socket* get_socket() const {
+ return _socket_and_control_bits.get();
+ }
uint32_t pipelined_count() const {
- return (_pc_and_udmsg >> 48) & 0x3FFF;
+ return _pc_and_udmsg.extra() & 0x3FFF;
}
uint32_t get_auth_flags() const {
- return (_pc_and_udmsg >> 62) & 0x03;
+ return (_pc_and_udmsg.extra() >> 14) & 0x03;
}
void clear_pipelined_count_and_auth_flags() {
- _pc_and_udmsg &= 0xFFFFFFFFFFFFULL;
+ _pc_and_udmsg.reset_extra();
}
SocketMessage* user_message() const {
- return (SocketMessage*)(_pc_and_udmsg & 0xFFFFFFFFFFFFULL);
+ return _pc_and_udmsg.get();
}
void clear_user_message() {
- _pc_and_udmsg &= 0xFFFF000000000000ULL;
+ _pc_and_udmsg.reset();
}
void set_pipelined_count_and_user_message(
uint32_t pc, SocketMessage* msg, uint32_t auth_flags) {
if (auth_flags) {
pc |= (auth_flags & 0x03) << 14;
}
- _pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg;
+ _pc_and_udmsg.set_ptr_and_extra(msg, pc);
}
bool reset_pipelined_count_and_user_message() {
@@ -355,7 +378,10 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
void Setup(Socket* s);
private:
- uint64_t _pc_and_udmsg;
+ // Socket pointer and some control bits.
+ PackedPtr<Socket> _socket_and_control_bits;
+ // User message pointer, pipelined count auth flag.
+ PackedPtr<SocketMessage> _pc_and_udmsg;
};
void Socket::WriteRequest::Setup(Socket* s) {
@@ -399,7 +425,7 @@ public:
EpollOutRequest() : fd(-1), timer_id(0)
, on_epollout_event(NULL), data(NULL) {}
- ~EpollOutRequest() {
+ ~EpollOutRequest() override {
// Remove the timer at last inside destructor to avoid
// race with the place that registers the timer
if (timer_id) {
@@ -408,8 +434,8 @@ public:
}
}
- void BeforeRecycle(Socket*) {
- // Recycle itself
+ void BeforeRecycle(Socket*) override {
+ // Recycle itself.
delete this;
}
@@ -464,6 +490,7 @@ Socket::Socket(Forbidden)
, _unwritten_bytes(0)
, _epollout_butex(NULL)
, _write_head(NULL)
+ , _is_write_shutdown(false)
, _stream_set(NULL)
, _total_streams_unconsumed_size(0)
, _ninflight_app_health_check(0)
@@ -485,7 +512,14 @@ void Socket::ReturnSuccessfulWriteRequest(Socket::WriteRequest* p) {
const bthread_id_t id_wait = p->id_wait;
+ // Read every field still needed from `p' BEFORE returning it to the
+ // object pool: once returned, `p' may be handed out to another writer.
+ const bool notify_on_success = p->is_notify_on_success();
butil::return_object(p);
if (id_wait != INVALID_BTHREAD_ID) {
- NotifyOnFailed(id_wait);
+ if (notify_on_success && !Failed()) {
+ bthread_id_error(id_wait, 0);
+ } else {
+ NotifyOnFailed(id_wait);
+ }
}
}
@@ -514,11 +545,18 @@ Socket::WriteRequest*
Socket::ReleaseWriteRequestsExceptLast(
}
void Socket::ReleaseAllFailedWriteRequests(Socket::WriteRequest* req) {
- CHECK(Failed());
- pthread_mutex_lock(&_id_wait_list_mutex);
- const int error_code = non_zero_error_code();
- const std::string error_text = _error_text;
- pthread_mutex_unlock(&_id_wait_list_mutex);
+ CHECK(Failed() || IsWriteShutdown());
+ int error_code;
+ std::string error_text;
+ if (Failed()) {
+ pthread_mutex_lock(&_id_wait_list_mutex);
+ error_code = non_zero_error_code();
+ error_text = _error_text;
+ pthread_mutex_unlock(&_id_wait_list_mutex);
+ } else {
+ error_code = ESHUTDOWNWRITE;
+ error_text = "Shutdown write of the socket";
+ }
// Notice that `req' is not tail if Address after IsWriteComplete fails.
do {
req = ReleaseWriteRequestsExceptLast(req, error_code, error_text);
@@ -746,6 +784,7 @@ int Socket::Create(const SocketOptions& options, SocketId*
id) {
m->_keepalive_options = options.keepalive_options;
m->_bthread_tag = options.bthread_tag;
CHECK(NULL == m->_write_head.load(butil::memory_order_relaxed));
+ m->_is_write_shutdown = false;
// Must be last one! Internal fields of this Socket may be access
// just after calling ResetFileDescriptor.
if (m->ResetFileDescriptor(options.fd) != 0) {
@@ -1382,7 +1421,7 @@ int Socket::ConnectIfNot(const timespec* abstime,
WriteRequest* req) {
// Have to hold a reference for `req'
SocketUniquePtr s;
ReAddress(&s);
- req->socket = s.get();
+ req->set_socket(s.get());
if (_conn) {
if (_conn->Connect(this, abstime, KeepWriteIfConnected, req) < 0) {
return -1;
@@ -1454,7 +1493,7 @@ int Socket::HandleEpollOutRequest(int error_code,
EpollOutRequest* req) {
void Socket::AfterAppConnected(int err, void* data) {
WriteRequest* req = static_cast<WriteRequest*>(data);
if (err == 0) {
- Socket* const s = req->socket;
+ Socket* const s = req->get_socket();
SharedPart* sp = s->GetSharedPart();
if (sp) {
sp->num_continuous_connect_timeouts.store(0,
butil::memory_order_relaxed);
@@ -1468,7 +1507,7 @@ void Socket::AfterAppConnected(int err, void* data) {
KeepWrite(req);
}
} else {
- SocketUniquePtr s(req->socket);
+ SocketUniquePtr s(req->get_socket());
if (err == ETIMEDOUT) {
SharedPart* sp = s->GetOrNewSharedPart();
if (sp->num_continuous_connect_timeouts.fetch_add(
@@ -1496,7 +1535,7 @@ static void* RunClosure(void* arg) {
int Socket::KeepWriteIfConnected(int fd, int err, void* data) {
WriteRequest* req = static_cast<WriteRequest*>(data);
- Socket* s = req->socket;
+ Socket* s = req->get_socket();
if (err == 0 && s->ssl_state() == SSL_CONNECTING) {
// Run ssl connect in a new bthread to avoid blocking
// the current bthread (thus blocking the EventDispatcher)
@@ -1519,7 +1558,7 @@ int Socket::KeepWriteIfConnected(int fd, int err, void*
data) {
void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) {
butil::fd_guard sockfd(fd);
WriteRequest* req = static_cast<WriteRequest*>(data);
- Socket* s = req->socket;
+ Socket* s = req->get_socket();
CHECK_GE(sockfd, 0);
if (err == 0 && s->CheckConnected(sockfd) == 0
&& s->ResetFileDescriptor(sockfd) == 0) {
@@ -1527,7 +1566,8 @@ void Socket::CheckConnectedAndKeepWrite(int fd, int err,
void* data) {
g_vars->channel_conn << 1;
}
if (s->_app_connect) {
- s->_app_connect->StartConnect(req->socket, AfterAppConnected, req);
+ s->_app_connect->StartConnect(req->get_socket(),
+ AfterAppConnected, req);
} else {
// Successfully created a connection
AfterAppConnected(0, req);
@@ -1614,6 +1654,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions*
options_in) {
// wait until it points to a valid WriteRequest or NULL.
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
+ req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write);
req->set_pipelined_count_and_user_message(
opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags);
return StartWrite(req, opt);
@@ -1650,7 +1691,9 @@ int Socket::Write(SocketMessagePtr<>& msg, const
WriteOptions* options_in) {
// wait until it points to a valid WriteRequest or NULL.
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
- req->set_pipelined_count_and_user_message(opt.pipelined_count,
msg.release(), opt.auth_flags);
+ req->clear_and_set_control_bits(opt.notify_on_success, opt.shutdown_write);
+ req->set_pipelined_count_and_user_message(
+ opt.pipelined_count, msg.release(), opt.auth_flags);
return StartWrite(req, opt);
}
@@ -1672,12 +1715,19 @@ int Socket::StartWrite(WriteRequest* req, const
WriteOptions& opt) {
bthread_t th;
SocketUniquePtr ptr_for_keep_write;
ssize_t nw = 0;
+ int ret = 0;
// We've got the right to write.
req->next = NULL;
+
+ // Fast fail when write has been shutdown.
+ if (_is_write_shutdown) {
+ goto FAIL_TO_WRITE;
+ }
+ _is_write_shutdown = req->need_shutdown_write();
// Connect to remote_side() if not.
- int ret = ConnectIfNot(opt.abstime, req);
+ ret = ConnectIfNot(opt.abstime, req);
if (ret < 0) {
saved_errno = errno;
SetFailed(errno, "Fail to connect %s directly: %m",
description().c_str());
@@ -1736,7 +1786,7 @@ int Socket::StartWrite(WriteRequest* req, const
WriteOptions& opt) {
KEEPWRITE_IN_BACKGROUND:
ReAddress(&ptr_for_keep_write);
- req->socket = ptr_for_keep_write.release();
+ req->set_socket(ptr_for_keep_write.release());
if (bthread_start_background(&th, &BTHREAD_ATTR_NORMAL,
KeepWrite, req) != 0) {
LOG(FATAL) << "Fail to start KeepWrite";
@@ -1758,7 +1808,7 @@ static const size_t DATA_LIST_MAX = 256;
void* Socket::KeepWrite(void* void_arg) {
g_vars->nkeepwrite << 1;
WriteRequest* req = static_cast<WriteRequest*>(void_arg);
- SocketUniquePtr s(req->socket);
+ SocketUniquePtr s(req->get_socket());
// When error occurs, spin until there's no more requests instead of
// returning directly otherwise _write_head is permantly non-NULL which
@@ -1766,11 +1816,18 @@ void* Socket::KeepWrite(void* void_arg) {
WriteRequest* cur_tail = NULL;
do {
// req was written, skip it.
+ bool need_shutdown = false;
if (req->next != NULL && req->data.empty()) {
WriteRequest* const saved_req = req;
+ need_shutdown = req->need_shutdown_write();
req = req->next;
s->ReturnSuccessfulWriteRequest(saved_req);
}
+ if (need_shutdown) {
+ LOG(WARNING) << "Shutdown write of " << *s;
+ break;
+ }
+
const ssize_t nw = s->DoWrite(req);
if (nw < 0) {
if (errno != EAGAIN && errno != EOVERCROWDED) {
@@ -1783,11 +1840,19 @@ void* Socket::KeepWrite(void* void_arg) {
} else {
s->AddOutputBytes(nw);
}
- // Release WriteRequest until non-empty data or last request.
+ // Release WriteRequest until non-empty data, last request or shutdown
write.
while (req->next != NULL && req->data.empty()) {
WriteRequest* const saved_req = req;
+ need_shutdown = req->need_shutdown_write();
req = req->next;
s->ReturnSuccessfulWriteRequest(saved_req);
+ if (need_shutdown) {
+ break;
+ }
+ }
+ if (need_shutdown) {
+ LOG(WARNING) << "Shutdown write of " << *s;
+ break;
}
// TODO(gejun): wait for epollout when we actually have written
// all the data. This weird heuristic reduces 30us delay...
@@ -1867,6 +1932,11 @@ ssize_t Socket::DoWrite(WriteRequest* req) {
for (WriteRequest* p = req; p != NULL && ndata < DATA_LIST_MAX;
p = p->next) {
data_list[ndata++] = &p->data;
+ if (p->need_shutdown_write()) {
+ // Write WriteRequest until shutdown write.
+ _is_write_shutdown = true;
+ break;
+ }
}
if (ssl_state() == SSL_OFF) {
@@ -2387,6 +2457,9 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) {
os << "\n}";
}
+ // Whether a write with `shutdown_write' has been started on this socket.
+ os << "\nis_write_shutdown=" << ptr->_is_write_shutdown;
+
{
int keepalive = 0;
socklen_t len = sizeof(keepalive);
diff --git a/src/brpc/socket.h b/src/brpc/socket.h
index faf6baac..97ce5685 100644
--- a/src/brpc/socket.h
+++ b/src/brpc/socket.h
@@ -166,6 +166,62 @@ struct PipelinedInfo {
bthread_id_t id_wait;
};
+// Packs a pointer and 16 bits of extra information into one uint64_t:
+// the pointer lives in the low 48 bits and the extra information in the
+// high 16 bits. Pointer bits above bit 47 are dropped, so only pointers
+// whose value fits in 48 bits survive a set()/get() round trip.
+template <class T>
+class PackedPtr {
+ static constexpr uint8_t MAX_POINTER_LEN = 48;
+ static constexpr uint64_t POINTER_MASK =
+ (static_cast<uint64_t>(1) << MAX_POINTER_LEN) - 1;
+ static constexpr uint64_t EXTRA_MASK = ~POINTER_MASK;
+public:
+ PackedPtr() : _data(0) {
+ BAIDU_CASSERT(sizeof(PackedPtr) == 8, sizeof_packed_ptr_must_be_8);
+ }
+
+ // Replaces the pointer (low 48 bits) and keeps the extra bits.
+ void set(T* ptr) {
+ _data = (_data & EXTRA_MASK) | pointer_bits(ptr);
+ }
+
+ // Clears the pointer (low 48 bits) and keeps the extra bits.
+ void reset() {
+ _data &= EXTRA_MASK;
+ }
+
+ // Returns the stored pointer, or NULL if none has been set.
+ T* get() const {
+ return reinterpret_cast<T*>(static_cast<uintptr_t>(_data & POINTER_MASK));
+ }
+
+ // Replaces the extra bits (high 16 bits) and keeps the pointer.
+ void set_extra(uint16_t extra) {
+ _data = (_data & POINTER_MASK) |
+ (static_cast<uint64_t>(extra) << MAX_POINTER_LEN);
+ }
+
+ // Clears the extra bits (high 16 bits) and keeps the pointer.
+ void reset_extra() {
+ _data &= POINTER_MASK;
+ }
+
+ // Returns the extra bits.
+ uint16_t extra() const {
+ return static_cast<uint16_t>(_data >> MAX_POINTER_LEN);
+ }
+
+ // Replaces both the pointer and the extra bits.
+ void set_ptr_and_extra(T* p, uint16_t extra) {
+ _data = pointer_bits(p) | (static_cast<uint64_t>(extra) << MAX_POINTER_LEN);
+ }
+
+ // Clears both the pointer and the extra bits.
+ void reset_ptr_and_extra() {
+ _data = 0;
+ }
+
+private:
+ // Returns the low 48 bits of `p', i.e. the part of it that is stored.
+ static uint64_t pointer_bits(T* p) {
+ return static_cast<uint64_t>(reinterpret_cast<uintptr_t>(p)) & POINTER_MASK;
+ }
+
+ // Pointer is stored in the low 48 bits,
+ // extra information is stored in the high 16 bits.
+ uint64_t _data;
+};
+
+
struct SocketSSLContext {
SocketSSLContext();
~SocketSSLContext();
@@ -269,11 +325,19 @@ public:
// - Write once when uncontended(most cases).
// - Wait-free when contended.
struct WriteOptions {
- // `id_wait' is signalled when this Socket is SetFailed. To disable
- // the signal, set this field to INVALID_BTHREAD_ID.
- // `on_reset' of `id_wait' is NOT called when Write() returns non-zero.
+ // `id_wait' is signalled when this Socket is SetFailed or data is written
+ // successfully with `notify_on_success=true'. To disable the signal, set
+ // this field to INVALID_BTHREAD_ID. `on_reset' of `id_wait' is NOT called
+ // when Write() returns non-zero.
// Default: INVALID_BTHREAD_ID
bthread_id_t id_wait;
+
+ // If this field is set to true and `id_wait' is not INVALID_BTHREAD_ID,
+ // `id_wait' is signalled with error code 0 once the data has been
+ // written successfully, unless the socket has failed by then.
+ // Default: false
+ bool notify_on_success;
+
// If no connection exists, a connection will be established to
// remote_side() regarding deadline `abstime'. NULL means no timeout.
// Default: NULL
@@ -301,13 +364,30 @@ public:
// performance. Otherwise, each write only writes one `msg` into socket
// and no KeepWrite thread can be created, which brings poor
// performance.
+ // Default: false
bool write_in_background;
+ // After this write completes, writing on the socket is shut down: later
+ // write requests fail with ESHUTDOWNWRITE, reported through their
+ // `id_wait'.
+ // Default: false
+ bool shutdown_write;
+
WriteOptions()
- : id_wait(INVALID_BTHREAD_ID), abstime(NULL)
- , pipelined_count(0), auth_flags(0)
- , ignore_eovercrowded(false), write_in_background(false) {}
+ : id_wait(INVALID_BTHREAD_ID)
+ , notify_on_success(false)
+ , abstime(NULL)
+ , pipelined_count(0)
+ , auth_flags(0)
+ , ignore_eovercrowded(false)
+ , write_in_background(false)
+ , shutdown_write(false) {}
};
+
+ // True once a write with `shutdown_write=true' has been started on this
+ // socket; subsequent writes are rejected.
+ bool IsWriteShutdown() const { return _is_write_shutdown; }
+
int Write(butil::IOBuf *msg, const WriteOptions* options = NULL);
// Write an user-defined message. `msg' is released when Write() is
@@ -917,6 +994,11 @@ private:
// Storing data that are not flushed into `fd' yet.
butil::atomic<WriteRequest*> _write_head;
+ // Set by the writer once a write with `shutdown_write' is started and
+ // read by IsWriteShutdown(), possibly from other threads; atomic so that
+ // such concurrent reads are not data races.
+ butil::atomic<bool> _is_write_shutdown;
+
butil::Mutex _stream_mutex;
std::set<StreamId> *_stream_set;
butil::atomic<int64_t> _total_streams_unconsumed_size;
diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp
index d2258735..f278c46b 100644
--- a/test/brpc_socket_unittest.cpp
+++ b/test/brpc_socket_unittest.cpp
@@ -1225,7 +1225,6 @@ TEST_F(SocketTest, keepalive) {
}
}
-
TEST_F(SocketTest, keepalive_input_message) {
int default_keepalive = 0;
int default_keepalive_idle = 0;
@@ -1418,3 +1417,217 @@ TEST_F(SocketTest, keepalive_input_message) {
sockfd.release();
}
}
+
+// bthread_id error handler for the notify_on_success test: expects every
+// signal to carry error code 0, increments the size_t counter passed as
+// `data' and destroys the id.
+int HandleSocketSuccessWrite(bthread_id_t id, void* data, int error_code,
+ const std::string& error_text) {
+ auto success_count = static_cast<size_t*>(data);
+ EXPECT_NE(nullptr, success_count);
+ EXPECT_EQ(0, error_code);
+ ++(*success_count);
+ CHECK_EQ(0, bthread_id_unlock_and_destroy(id));
+ return 0;
+}
+
+// Writes REP small messages with notify_on_success=true and checks that
+// each of them is reported as successful through its own `id_wait'.
+TEST_F(SocketTest, notify_on_success) {
+ const size_t REP = 10000;
+ int fds[2];
+ ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
+
+ brpc::SocketId id = 8888;
+ butil::EndPoint dummy;
+ ASSERT_EQ(0, str2endpoint("192.168.1.26:8080", &dummy));
+ brpc::SocketOptions options;
+ options.fd = fds[1];
+ options.remote_side = dummy;
+ options.user = new CheckRecycle;
+ ASSERT_EQ(0, brpc::Socket::Create(options, &id));
+ brpc::SocketUniquePtr s;
+ ASSERT_EQ(0, brpc::Socket::Address(id, &s));
+ s->_ssl_state = brpc::SSL_OFF;
+ ASSERT_EQ(2, brpc::NRefOfVRef(s->_versioned_ref));
+ global_sock = s.get();
+ ASSERT_TRUE(s.get());
+ ASSERT_EQ(fds[1], s->fd());
+ ASSERT_EQ(dummy, s->remote_side());
+ ASSERT_EQ(id, s->id());
+
+ pthread_t rth;
+ ReaderArg reader_arg = { fds[0], 0 };
+ pthread_create(&rth, NULL, reader, &reader_arg);
+
+ size_t success_count = 0;
+ char buf[] = "hello reader side!";
+ for (size_t c = 0; c < REP; ++c) {
+ bthread_id_t write_id;
+ ASSERT_EQ(0, bthread_id_create2(&write_id, &success_count,
+ HandleSocketSuccessWrite));
+ brpc::Socket::WriteOptions wopt;
+ wopt.id_wait = write_id;
+ wopt.notify_on_success = true;
+ butil::IOBuf src;
+ // Only the first 16 bytes of `buf' are sent.
+ src.append(buf, 16);
+ if (s->Write(&src, &wopt) != 0) {
+ if (errno == brpc::EOVERCROWDED) {
+ // The buf is full, sleep a while and retry.
+ bthread_usleep(1000);
+ --c;
+ continue;
+ }
+ PLOG(ERROR) << "Fail to write into SocketId=" << id;
+ break;
+ }
+ }
+ // Heuristic wait so that queued writes finish and their ids are
+ // signalled before the socket is failed below.
+ bthread_usleep(1000 * 1000);
+
+ ASSERT_EQ(0, s->SetFailed());
+ s.release()->Dereference();
+ pthread_join(rth, NULL);
+ ASSERT_EQ(REP, success_count);
+ ASSERT_EQ((brpc::Socket*)NULL, global_sock);
+ close(fds[0]);
+}
+
+// State shared by one ShutdownWriter bthread and the id handler
+// HandleSocketShutdownWrite invoked for each of its writes.
+struct ShutdownWriterArg {
+ size_t times;
+ brpc::SocketId socket_id;
+ butil::atomic<int> total_count;
+ butil::atomic<int> success_count;
+};
+
+// bthread_id error handler for the shutdown_write test: accepts only
+// success (0) or ESHUTDOWNWRITE, counts every signal and the successful
+// ones separately, then destroys the id.
+int HandleSocketShutdownWrite(bthread_id_t id, void* data, int error_code,
+ const std::string& error_text) {
+ auto arg = static_cast<ShutdownWriterArg*>(data);
+ EXPECT_NE(nullptr, arg);
+ EXPECT_TRUE(0 == error_code || brpc::ESHUTDOWNWRITE == error_code) << error_code;
+ ++arg->total_count;
+ if (0 == error_code) {
+ ++arg->success_count;
+ }
+ CHECK_EQ(0, bthread_id_unlock_and_destroy(id));
+ return 0;
+}
+
+// bthread entry: issues `arg->times' one-byte writes on `arg->socket_id',
+// each with notify_on_success and shutdown_write set, so the outcome of
+// every write is reported to HandleSocketShutdownWrite.
+void* ShutdownWriter(void* void_arg) {
+ auto arg = static_cast<ShutdownWriterArg*>(void_arg);
+ brpc::SocketUniquePtr sock;
+ if (brpc::Socket::Address(arg->socket_id, &sock) < 0) {
+ LOG(INFO) << "Fail to address SocketId=" << arg->socket_id;
+ return NULL;
+ }
+ for (size_t c = 0; c < arg->times; ++c) {
+ bthread_id_t write_id;
+ EXPECT_EQ(0, bthread_id_create2(&write_id, arg,
+ HandleSocketShutdownWrite));
+ brpc::Socket::WriteOptions wopt;
+ wopt.id_wait = write_id;
+ wopt.notify_on_success = true;
+ wopt.shutdown_write = true;
+ butil::IOBuf src;
+ src.push_back('a');
+ if (sock->Write(&src, &wopt) != 0) {
+ if (errno == brpc::EOVERCROWDED) {
+ // The buf is full, sleep a while and retry.
+ bthread_usleep(1000);
+ --c;
+ continue;
+ }
+ }
+ }
+ return NULL;
+}
+
+// Runs three ShutdownWriter bthreads concurrently against one socket and
+// checks that every write id is signalled (success or ESHUTDOWNWRITE),
+// exactly one write succeeds, the reader observes nread == 1, and the
+// socket ends up write-shutdown but not failed.
+void TestShutdownWrite() {
+ const size_t REP = 100;
+ int fds[2];
+ ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
+
+ brpc::SocketId id = 8888;
+ butil::EndPoint dummy;
+ ASSERT_EQ(0, str2endpoint("192.168.1.26:8080", &dummy));
+ brpc::SocketOptions options;
+ options.fd = fds[1];
+ options.remote_side = dummy;
+ options.user = new CheckRecycle;
+ ASSERT_EQ(0, brpc::Socket::Create(options, &id));
+ brpc::SocketUniquePtr s;
+ ASSERT_EQ(0, brpc::Socket::Address(id, &s));
+ s->_ssl_state = brpc::SSL_OFF;
+ ASSERT_EQ(2, brpc::NRefOfVRef(s->_versioned_ref));
+ global_sock = s.get();
+ ASSERT_TRUE(s.get());
+ ASSERT_EQ(fds[1], s->fd());
+ ASSERT_EQ(dummy, s->remote_side());
+ ASSERT_EQ(id, s->id());
+ ASSERT_FALSE(s->IsWriteShutdown());
+
+ pthread_t rth;
+ ReaderArg reader_arg = { fds[0], 0 };
+ pthread_create(&rth, NULL, reader, &reader_arg);
+
+ bthread_t th[3];
+ ShutdownWriterArg args[ARRAY_SIZE(th)];
+ for (size_t i = 0; i < ARRAY_SIZE(th); ++i) {
+ args[i].times = REP;
+ args[i].socket_id = id;
+ args[i].total_count = 0;
+ args[i].success_count = 0;
+ bthread_start_background(&th[i], NULL, ShutdownWriter, &args[i]);
+ }
+
+ for (size_t i = 0; i < ARRAY_SIZE(th); ++i) {
+ ASSERT_EQ(0, bthread_join(th[i], NULL));
+ }
+ // Heuristic wait for in-flight writes and notifications to settle.
+ bthread_usleep(50 * 1000);
+
+ ASSERT_TRUE(s->IsWriteShutdown());
+ ASSERT_FALSE(s->Failed());
+ ASSERT_EQ(0, s->SetFailed());
+ s.release()->Dereference();
+ pthread_join(rth, NULL);
+ ASSERT_EQ((brpc::Socket*)NULL, global_sock);
+ close(fds[0]);
+
+ size_t total_count = 0;
+ size_t success_count = 0;
+ for (auto & arg : args) {
+ total_count += arg.total_count;
+ success_count += arg.success_count;
+ }
+ ASSERT_EQ(REP * ARRAY_SIZE(th), total_count);
+ EXPECT_EQ((size_t)1, reader_arg.nread);
+ EXPECT_EQ((size_t)1, success_count);
+}
+
+// Repeats the racy shutdown scenario to shake out ordering issues.
+TEST_F(SocketTest, shutdown_write) {
+ for (int i = 0; i < 100; ++i) {
+ TestShutdownWrite();
+ }
+}
+
+// Exercises set/get/reset of the pointer and extra parts of PackedPtr,
+// separately and together.
+TEST_F(SocketTest, packed_ptr) {
+ brpc::PackedPtr<int> ptr;
+ ASSERT_EQ(nullptr, ptr.get());
+ ASSERT_EQ(0, ptr.extra());
+
+ int a = 1;
+ uint16_t b = 2;
+ ptr.set(&a);
+ ASSERT_EQ(&a, ptr.get());
+ *ptr.get() = b;
+ ASSERT_EQ(a, b);
+ ptr.set_extra(b);
+ ASSERT_EQ(b, ptr.extra());
+ ptr.reset();
+ ptr.reset_extra();
+ ASSERT_EQ(nullptr, ptr.get());
+ ASSERT_EQ(0, ptr.extra());
+
+ int c = 3;
+ uint16_t d = 4;
+ ptr.set_ptr_and_extra(&c, d);
+ ASSERT_EQ(&c, ptr.get());
+ ASSERT_EQ(d, ptr.extra());
+ ptr.reset_ptr_and_extra();
+ ASSERT_EQ(nullptr, ptr.get());
+ ASSERT_EQ(0, ptr.extra());
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]