This is an automated email from the ASF dual-hosted git repository. guangmingchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/brpc.git
The following commit(s) were added to refs/heads/master by this push: new ba9e8388 Fix thread safety of Wrapper (#2952) ba9e8388 is described below commit ba9e8388b670f7df0eae7a9bb51da5fac08435e2 Author: Bright Chen <chenguangmin...@foxmail.com> AuthorDate: Mon Apr 21 09:52:24 2025 +0800 Fix thread safety of Wrapper (#2952) --- src/brpc/policy/http_rpc_protocol.cpp | 6 +- src/brpc/server.cpp | 2 +- src/butil/containers/doubly_buffered_data.h | 152 ++++++++++++++-------------- src/json2pb/pb_to_json.cpp | 4 +- 4 files changed, 78 insertions(+), 86 deletions(-) diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp index e0ff1e31..007bce39 100644 --- a/src/brpc/policy/http_rpc_protocol.cpp +++ b/src/brpc/policy/http_rpc_protocol.cpp @@ -331,11 +331,7 @@ static bool ProtoMessageToProtoJson(const google::protobuf::Message& message, butil::IOBufAsZeroCopyOutputStream* wrapper, Controller* cntl, int error_code) { json2pb::Pb2ProtoJsonOptions options; -#if GOOGLE_PROTOBUF_VERSION >= 5026002 - options.always_print_fields_with_no_presence = cntl->has_always_print_primitive_fields(); -#else - options.always_print_primitive_fields = cntl->has_always_print_primitive_fields(); -#endif + AlwaysPrintPrimitiveFields(options) = cntl->has_always_print_primitive_fields(); options.always_print_enums_as_ints = FLAGS_pb_enum_as_number; std::string error; bool ok = json2pb::ProtoMessageToProtoJson(message, wrapper, options, &error); diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index b6c3d3ee..aa55c858 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -1972,7 +1972,7 @@ bool IsDummyServerRunning() { } const Server::MethodProperty* -Server::FindMethodPropertyByFullName(const butil::StringPiece&fullname) const { +Server::FindMethodPropertyByFullName(const butil::StringPiece& fullname) const { return _method_map.seek(fullname); } diff --git a/src/butil/containers/doubly_buffered_data.h b/src/butil/containers/doubly_buffered_data.h 
index 5aacece3..ff96a903 100644 --- a/src/butil/containers/doubly_buffered_data.h +++ b/src/butil/containers/doubly_buffered_data.h @@ -21,7 +21,8 @@ #define BUTIL_DOUBLY_BUFFERED_DATA_H #include <deque> -#include <vector> // std::vector +#include <vector> +#include <memory> #include <pthread.h> #include "butil/scoped_lock.h" #include "butil/thread_local.h" @@ -87,6 +88,8 @@ class DoublyBufferedData { class Wrapper; class WrapperTLSGroup; typedef int WrapperTLSId; + typedef std::shared_ptr<Wrapper> WrapperSharedPtr; + typedef std::weak_ptr<Wrapper> WrapperWeakPtr; public: class ScopedPtr { friend class DoublyBufferedData; @@ -111,7 +114,7 @@ public: const T* _data; // Index of foreground instance used by ScopedPtr. int _index; - Wrapper* _w; + WrapperSharedPtr _w; }; DoublyBufferedData(); @@ -152,8 +155,7 @@ private: return _data + index; } - Wrapper* AddWrapper(Wrapper*); - void RemoveWrapper(Wrapper*); + WrapperSharedPtr GetWrapper(); // Foreground and background void. T _data[2]; @@ -165,7 +167,7 @@ private: WrapperTLSId _wrapper_key; // All thread-local instances. - std::vector<Wrapper*> _wrappers; + std::vector<WrapperWeakPtr> _wrappers; // Sequence access to _wrappers. pthread_mutex_t _wrappers_mutex{}; @@ -195,18 +197,22 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup { public: const static size_t RAW_BLOCK_SIZE = 4096; const static size_t ELEMENTS_PER_BLOCK = - RAW_BLOCK_SIZE / sizeof(Wrapper) > 0 ? RAW_BLOCK_SIZE / sizeof(Wrapper) : 1; + RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) > 0 ? 
+ RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) : 1; struct BAIDU_CACHELINE_ALIGNMENT ThreadBlock { - inline DoublyBufferedData::Wrapper* at(size_t offset) { - return _data + offset; + WrapperSharedPtr at(size_t offset) { + if (NULL == _data[offset]) { + _data[offset] = std::make_shared<Wrapper>(); + } + return _data[offset]; }; private: - DoublyBufferedData::Wrapper _data[ELEMENTS_PER_BLOCK]; + WrapperSharedPtr _data[ELEMENTS_PER_BLOCK]; }; - inline static WrapperTLSId key_create() { + static WrapperTLSId key_create() { BAIDU_SCOPED_LOCK(_s_mutex); WrapperTLSId id = 0; if (!_get_free_ids().empty()) { @@ -218,7 +224,7 @@ public: return id; } - inline static int key_delete(WrapperTLSId id) { + static int key_delete(WrapperTLSId id) { BAIDU_SCOPED_LOCK(_s_mutex); if (id < 0 || id >= _s_id) { errno = EINVAL; @@ -228,17 +234,13 @@ public: return 0; } - inline static DoublyBufferedData::Wrapper* get_or_create_tls_data(WrapperTLSId id) { + static WrapperSharedPtr get_or_create_tls_data(WrapperTLSId id) { if (BAIDU_UNLIKELY(id < 0)) { CHECK(false) << "Invalid id=" << id; return NULL; } if (_s_tls_blocks == NULL) { - _s_tls_blocks = new (std::nothrow) std::vector<ThreadBlock*>; - if (BAIDU_UNLIKELY(_s_tls_blocks == NULL)) { - LOG(FATAL) << "Fail to create vector, " << berror(); - return NULL; - } + _s_tls_blocks = new std::vector<ThreadBlock*>; butil::thread_atexit(_destroy_tls_blocks); } const size_t block_id = (size_t)id / ELEMENTS_PER_BLOCK; @@ -248,12 +250,8 @@ public: } ThreadBlock* tb = (*_s_tls_blocks)[block_id]; if (tb == NULL) { - ThreadBlock* new_block = new (std::nothrow) ThreadBlock; - if (BAIDU_UNLIKELY(new_block == NULL)) { - return NULL; - } - tb = new_block; - (*_s_tls_blocks)[block_id] = new_block; + tb = new ThreadBlock; + (*_s_tls_blocks)[block_id] = tb; } return tb->at(id - block_id * ELEMENTS_PER_BLOCK); } @@ -316,10 +314,6 @@ public: } ~Wrapper() { - if (_control != NULL) { - _control->RemoveWrapper(this); - } - if (AllowBthreadSuspended) { 
WaitReadDone(0); WaitReadDone(1); @@ -406,9 +400,9 @@ private: // Called when thread initializes thread-local wrapper. template <typename T, typename TLS, bool AllowBthreadSuspended> -typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* -DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper( - typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) { +typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperSharedPtr +DoublyBufferedData<T, TLS, AllowBthreadSuspended>::GetWrapper() { + WrapperSharedPtr w = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key); if (NULL == w) { return NULL; } @@ -423,29 +417,19 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper( w->_control = this; BAIDU_SCOPED_LOCK(_wrappers_mutex); _wrappers.push_back(w); + // The chance to remove expired weak_ptr. + _wrappers.erase( + std::remove_if(_wrappers.begin(), _wrappers.end(), + [](const WrapperWeakPtr& w) { + return w.expired(); + }), + _wrappers.end()); } catch (std::exception& e) { return NULL; } return w; } -// Called when thread quits. -template <typename T, typename TLS, bool AllowBthreadSuspended> -void DoublyBufferedData<T, TLS, AllowBthreadSuspended>::RemoveWrapper( - typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) { - if (NULL == w) { - return; - } - BAIDU_SCOPED_LOCK(_wrappers_mutex); - for (size_t i = 0; i < _wrappers.size(); ++i) { - if (_wrappers[i] == w) { - _wrappers[i] = _wrappers.back(); - _wrappers.pop_back(); - return; - } - } -} - template <typename T, typename TLS, bool AllowBthreadSuspended> DoublyBufferedData<T, TLS, AllowBthreadSuspended>::DoublyBufferedData() : _index(0) @@ -474,7 +458,10 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() { { BAIDU_SCOPED_LOCK(_wrappers_mutex); for (size_t i = 0; i < _wrappers.size(); ++i) { - _wrappers[i]->_control = NULL; // hack: disable removal. 
+ WrapperSharedPtr w = _wrappers[i].lock(); + if (NULL != w) { + w->_control = NULL; // hack: disable removal. + } } _wrappers.clear(); } @@ -487,29 +474,28 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() { template <typename T, typename TLS, bool AllowBthreadSuspended> int DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Read( typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::ScopedPtr* ptr) { - Wrapper* p = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key); - Wrapper* w = AddWrapper(p); - if (BAIDU_LIKELY(w != NULL)) { - if (AllowBthreadSuspended) { - // Use reference count instead of mutex to indicate read of - // foreground instance, so during the read process, there is - // no need to lock mutex and bthread is allowed to be suspended. - w->BeginRead(); - int index = -1; - ptr->_data = UnsafeRead(index); - ptr->_index = index; - w->AddRef(index); - ptr->_w = w; - w->BeginReadRelease(); - } else { - w->BeginRead(); - ptr->_data = UnsafeRead(); - ptr->_w = w; - } + WrapperSharedPtr w = GetWrapper(); + if (BAIDU_UNLIKELY(w == NULL)) { + return -1; + } - return 0; + if (AllowBthreadSuspended) { + // Use reference count instead of mutex to indicate read of + // foreground instance, so during the read process, there is + // no need to lock mutex and bthread is allowed to be suspended. + w->BeginRead(); + int index = -1; + ptr->_data = UnsafeRead(index); + ptr->_index = index; + w->AddRef(index); + ptr->_w = w; + w->BeginReadRelease(); + } else { + w->BeginRead(); + ptr->_data = UnsafeRead(); + ptr->_w = w; } - return -1; + return 0; } template <typename T, typename TLS, bool AllowBthreadSuspended> @@ -530,7 +516,7 @@ template <typename Fn, typename... Args> size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&&... args) { // _modify_mutex sequences modifications. Using a separate mutex rather // than _wrappers_mutex is to avoid blocking threads calling - // AddWrapper() or RemoveWrapper() too long. 
Most of the time, modifications + // GetWrapper() too long. Most of the time, modifications // are done by one thread, contention should be negligible. BAIDU_SCOPED_LOCK(_modify_mutex); int bg_index = !_index.load(butil::memory_order_relaxed); @@ -552,14 +538,24 @@ size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&& // read, they should see updated _index. { BAIDU_SCOPED_LOCK(_wrappers_mutex); - for (size_t i = 0; i < _wrappers.size(); ++i) { - // Wait read of old foreground instance done. - if (AllowBthreadSuspended) { - _wrappers[i]->WaitReadDone(bg_index); - } else { - _wrappers[i]->WaitReadDone(); - } - } + // The chance to remove expired weak_ptr. + _wrappers.erase( + std::remove_if(_wrappers.begin(), _wrappers.end(), + [bg_index](const WrapperWeakPtr& weak) { + WrapperSharedPtr w = weak.lock(); + bool expired = NULL == w; + if (!expired) { + // Notify all threads waiting for read done. + if (AllowBthreadSuspended) { + w->WaitReadDone(bg_index); + } else { + w->WaitReadDone(); + } + } + // Remove expired weak_ptr. 
+                    return expired; +                }), +                _wrappers.end()); } const size_t ret2 = fn(_data[bg_index], std::forward<Args>(args)...); diff --git a/src/json2pb/pb_to_json.cpp b/src/json2pb/pb_to_json.cpp index c23ccdf7..e37cc87d 100644 --- a/src/json2pb/pb_to_json.cpp +++ b/src/json2pb/pb_to_json.cpp @@ -336,14 +336,14 @@ bool ProtoMessageToJson(const google::protobuf::Message& message, } bool ProtoMessageToJson(const google::protobuf::Message& message, -                        google::protobuf::io::ZeroCopyOutputStream *stream, +                        google::protobuf::io::ZeroCopyOutputStream* stream, const Pb2JsonOptions& options, std::string* error) { json2pb::ZeroCopyStreamWriter wrapper(stream); return json2pb::ProtoMessageToJsonStream(message, options, wrapper, error); } bool ProtoMessageToJson(const google::protobuf::Message& message, -                        google::protobuf::io::ZeroCopyOutputStream *stream, +                        google::protobuf::io::ZeroCopyOutputStream* stream, std::string* error) { return ProtoMessageToJson(message, stream, Pb2JsonOptions(), error); } --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscribe@brpc.apache.org For additional commands, e-mail: dev-help@brpc.apache.org