This is an automated email from the ASF dual-hosted git repository.

guangmingchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git


The following commit(s) were added to refs/heads/master by this push:
     new ba9e8388 Fix thread safety of Wrapper (#2952)
ba9e8388 is described below

commit ba9e8388b670f7df0eae7a9bb51da5fac08435e2
Author: Bright Chen <chenguangmin...@foxmail.com>
AuthorDate: Mon Apr 21 09:52:24 2025 +0800

    Fix thread safety of Wrapper (#2952)
---
 src/brpc/policy/http_rpc_protocol.cpp       |   6 +-
 src/brpc/server.cpp                         |   2 +-
 src/butil/containers/doubly_buffered_data.h | 152 ++++++++++++++--------------
 src/json2pb/pb_to_json.cpp                  |   4 +-
 4 files changed, 78 insertions(+), 86 deletions(-)

diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp
index e0ff1e31..007bce39 100644
--- a/src/brpc/policy/http_rpc_protocol.cpp
+++ b/src/brpc/policy/http_rpc_protocol.cpp
@@ -331,11 +331,7 @@ static bool ProtoMessageToProtoJson(const google::protobuf::Message& message,
                                     butil::IOBufAsZeroCopyOutputStream* 
wrapper,
                                     Controller* cntl, int error_code) {
     json2pb::Pb2ProtoJsonOptions options;
-#if GOOGLE_PROTOBUF_VERSION >= 5026002
-    options.always_print_fields_with_no_presence = 
cntl->has_always_print_primitive_fields();
-#else
-    options.always_print_primitive_fields = 
cntl->has_always_print_primitive_fields();
-#endif
+    AlwaysPrintPrimitiveFields(options) = 
cntl->has_always_print_primitive_fields();
     options.always_print_enums_as_ints = FLAGS_pb_enum_as_number;
     std::string error;
     bool ok = json2pb::ProtoMessageToProtoJson(message, wrapper, options, 
&error);
diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp
index b6c3d3ee..aa55c858 100644
--- a/src/brpc/server.cpp
+++ b/src/brpc/server.cpp
@@ -1972,7 +1972,7 @@ bool IsDummyServerRunning() {
 }
 
 const Server::MethodProperty*
-Server::FindMethodPropertyByFullName(const butil::StringPiece&fullname) const  
{
+Server::FindMethodPropertyByFullName(const butil::StringPiece& fullname) const 
 {
     return _method_map.seek(fullname);
 }
 
diff --git a/src/butil/containers/doubly_buffered_data.h b/src/butil/containers/doubly_buffered_data.h
index 5aacece3..ff96a903 100644
--- a/src/butil/containers/doubly_buffered_data.h
+++ b/src/butil/containers/doubly_buffered_data.h
@@ -21,7 +21,8 @@
 #define BUTIL_DOUBLY_BUFFERED_DATA_H
 
 #include <deque>
-#include <vector>                                       // std::vector
+#include <vector>
+#include <memory>
 #include <pthread.h>
 #include "butil/scoped_lock.h"
 #include "butil/thread_local.h"
@@ -87,6 +88,8 @@ class DoublyBufferedData {
     class Wrapper;
     class WrapperTLSGroup;
     typedef int WrapperTLSId;
+    typedef std::shared_ptr<Wrapper> WrapperSharedPtr;
+    typedef std::weak_ptr<Wrapper> WrapperWeakPtr;
 public:
     class ScopedPtr {
     friend class DoublyBufferedData;
@@ -111,7 +114,7 @@ public:
         const T* _data;
         // Index of foreground instance used by ScopedPtr.
         int _index;
-        Wrapper* _w;
+        WrapperSharedPtr _w;
     };
     
     DoublyBufferedData();
@@ -152,8 +155,7 @@ private:
         return _data + index;
     }
 
-    Wrapper* AddWrapper(Wrapper*);
-    void RemoveWrapper(Wrapper*);
+    WrapperSharedPtr GetWrapper();
 
     // Foreground and background void.
     T _data[2];
@@ -165,7 +167,7 @@ private:
     WrapperTLSId _wrapper_key;
 
     // All thread-local instances.
-    std::vector<Wrapper*> _wrappers;
+    std::vector<WrapperWeakPtr> _wrappers;
 
     // Sequence access to _wrappers.
     pthread_mutex_t _wrappers_mutex{};
@@ -195,18 +197,22 @@ class DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperTLSGroup {
 public:
     const static size_t RAW_BLOCK_SIZE = 4096;
     const static size_t ELEMENTS_PER_BLOCK =
-        RAW_BLOCK_SIZE / sizeof(Wrapper) > 0 ? RAW_BLOCK_SIZE / 
sizeof(Wrapper) : 1;
+        RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) > 0 ?
+            RAW_BLOCK_SIZE / sizeof(WrapperSharedPtr) : 1;
 
     struct BAIDU_CACHELINE_ALIGNMENT ThreadBlock {
-        inline DoublyBufferedData::Wrapper* at(size_t offset) {
-            return _data + offset;
+        WrapperSharedPtr at(size_t offset) {
+            if (NULL == _data[offset]) {
+                _data[offset] = std::make_shared<Wrapper>();
+            }
+            return _data[offset];
         };
 
     private:
-        DoublyBufferedData::Wrapper _data[ELEMENTS_PER_BLOCK];
+        WrapperSharedPtr _data[ELEMENTS_PER_BLOCK];
     };
 
-    inline static WrapperTLSId key_create() {
+    static WrapperTLSId key_create() {
         BAIDU_SCOPED_LOCK(_s_mutex);
         WrapperTLSId id = 0;
         if (!_get_free_ids().empty()) {
@@ -218,7 +224,7 @@ public:
         return id;
     }
 
-    inline static int key_delete(WrapperTLSId id) {
+    static int key_delete(WrapperTLSId id) {
         BAIDU_SCOPED_LOCK(_s_mutex);
         if (id < 0 || id >= _s_id) {
             errno = EINVAL;
@@ -228,17 +234,13 @@ public:
         return 0;
     }
 
-    inline static DoublyBufferedData::Wrapper* 
get_or_create_tls_data(WrapperTLSId id) {
+    static WrapperSharedPtr get_or_create_tls_data(WrapperTLSId id) {
         if (BAIDU_UNLIKELY(id < 0)) {
             CHECK(false) << "Invalid id=" << id;
             return NULL;
         }
         if (_s_tls_blocks == NULL) {
-            _s_tls_blocks = new (std::nothrow) std::vector<ThreadBlock*>;
-            if (BAIDU_UNLIKELY(_s_tls_blocks == NULL)) {
-                LOG(FATAL) << "Fail to create vector, " << berror();
-                return NULL;
-            }
+            _s_tls_blocks = new std::vector<ThreadBlock*>;
             butil::thread_atexit(_destroy_tls_blocks);
         }
         const size_t block_id = (size_t)id / ELEMENTS_PER_BLOCK;
@@ -248,12 +250,8 @@ public:
         }
         ThreadBlock* tb = (*_s_tls_blocks)[block_id];
         if (tb == NULL) {
-            ThreadBlock* new_block = new (std::nothrow) ThreadBlock;
-            if (BAIDU_UNLIKELY(new_block == NULL)) {
-                return NULL;
-            }
-            tb = new_block;
-            (*_s_tls_blocks)[block_id] = new_block;
+            tb = new ThreadBlock;
+            (*_s_tls_blocks)[block_id] = tb;
         }
         return tb->at(id - block_id * ELEMENTS_PER_BLOCK);
     }
@@ -316,10 +314,6 @@ public:
     }
     
     ~Wrapper() {
-        if (_control != NULL) {
-            _control->RemoveWrapper(this);
-        }
-
         if (AllowBthreadSuspended) {
             WaitReadDone(0);
             WaitReadDone(1);
@@ -406,9 +400,9 @@ private:
 
 // Called when thread initializes thread-local wrapper.
 template <typename T, typename TLS, bool AllowBthreadSuspended>
-typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper*
-DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper(
-        typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* 
w) {
+typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::WrapperSharedPtr
+DoublyBufferedData<T, TLS, AllowBthreadSuspended>::GetWrapper() {
+    WrapperSharedPtr w = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key);
     if (NULL == w) {
         return NULL;
     }
@@ -423,29 +417,19 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::AddWrapper(
         w->_control = this;
         BAIDU_SCOPED_LOCK(_wrappers_mutex);
         _wrappers.push_back(w);
+        // The chance to remove expired weak_ptr.
+        _wrappers.erase(
+            std::remove_if(_wrappers.begin(), _wrappers.end(),
+                           [](const WrapperWeakPtr& w) {
+                    return w.expired();
+                }),
+            _wrappers.end());
     } catch (std::exception& e) {
         return NULL;
     }
     return w;
 }
 
-// Called when thread quits.
-template <typename T, typename TLS, bool AllowBthreadSuspended>
-void DoublyBufferedData<T, TLS, AllowBthreadSuspended>::RemoveWrapper(
-    typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Wrapper* w) {
-    if (NULL == w) {
-        return;
-    }
-    BAIDU_SCOPED_LOCK(_wrappers_mutex);
-    for (size_t i = 0; i < _wrappers.size(); ++i) {
-        if (_wrappers[i] == w) {
-            _wrappers[i] = _wrappers.back();
-            _wrappers.pop_back();
-            return;
-        }
-    }
-}
-
 template <typename T, typename TLS, bool AllowBthreadSuspended>
 DoublyBufferedData<T, TLS, AllowBthreadSuspended>::DoublyBufferedData()
     : _index(0)
@@ -474,7 +458,10 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() {
     {
         BAIDU_SCOPED_LOCK(_wrappers_mutex);
         for (size_t i = 0; i < _wrappers.size(); ++i) {
-            _wrappers[i]->_control = NULL;  // hack: disable removal.
+            WrapperSharedPtr w = _wrappers[i].lock();
+            if (NULL != w) {
+                w->_control = NULL;  // hack: disable removal.
+            }
         }
         _wrappers.clear();
     }
@@ -487,29 +474,28 @@ DoublyBufferedData<T, TLS, AllowBthreadSuspended>::~DoublyBufferedData() {
 template <typename T, typename TLS, bool AllowBthreadSuspended>
 int DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Read(
     typename DoublyBufferedData<T, TLS, AllowBthreadSuspended>::ScopedPtr* 
ptr) {
-    Wrapper* p = WrapperTLSGroup::get_or_create_tls_data(_wrapper_key);
-    Wrapper* w = AddWrapper(p);
-    if (BAIDU_LIKELY(w != NULL)) {
-        if (AllowBthreadSuspended) {
-            // Use reference count instead of mutex to indicate read of
-            // foreground instance, so during the read process, there is
-            // no need to lock mutex and bthread is allowed to be suspended.
-            w->BeginRead();
-            int index = -1;
-            ptr->_data = UnsafeRead(index);
-            ptr->_index = index;
-            w->AddRef(index);
-            ptr->_w = w;
-            w->BeginReadRelease();
-        } else {
-            w->BeginRead();
-            ptr->_data = UnsafeRead();
-            ptr->_w = w;
-        }
+    WrapperSharedPtr w = GetWrapper();
+    if (BAIDU_UNLIKELY(w == NULL)) {
+        return -1;
+    }
 
-        return 0;
+    if (AllowBthreadSuspended) {
+        // Use reference count instead of mutex to indicate read of
+        // foreground instance, so during the read process, there is
+        // no need to lock mutex and bthread is allowed to be suspended.
+        w->BeginRead();
+        int index = -1;
+        ptr->_data = UnsafeRead(index);
+        ptr->_index = index;
+        w->AddRef(index);
+        ptr->_w = w;
+        w->BeginReadRelease();
+    } else {
+        w->BeginRead();
+        ptr->_data = UnsafeRead();
+        ptr->_w = w;
     }
-    return -1;
+    return 0;
 }
 
 template <typename T, typename TLS, bool AllowBthreadSuspended>
@@ -530,7 +516,7 @@ template <typename Fn, typename... Args>
 size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, 
Args&&... args) {
     // _modify_mutex sequences modifications. Using a separate mutex rather
     // than _wrappers_mutex is to avoid blocking threads calling
-    // AddWrapper() or RemoveWrapper() too long. Most of the time, 
modifications
+    // GetWrapper() too long. Most of the time, modifications
     // are done by one thread, contention should be negligible.
     BAIDU_SCOPED_LOCK(_modify_mutex);
     int bg_index = !_index.load(butil::memory_order_relaxed);
@@ -552,14 +538,24 @@ size_t DoublyBufferedData<T, TLS, AllowBthreadSuspended>::Modify(Fn&& fn, Args&&
     // read, they should see updated _index.
     {
         BAIDU_SCOPED_LOCK(_wrappers_mutex);
-        for (size_t i = 0; i < _wrappers.size(); ++i) {
-            // Wait read of old foreground instance done.
-            if (AllowBthreadSuspended) {
-                _wrappers[i]->WaitReadDone(bg_index);
-            } else {
-                _wrappers[i]->WaitReadDone();
-            }
-        }
+        // The chance to remove expired weak_ptr.
+        _wrappers.erase(
+            std::remove_if(_wrappers.begin(), _wrappers.end(),
+                           [bg_index](const WrapperWeakPtr& weak) {
+                WrapperSharedPtr w = weak.lock();
+                bool expired = NULL == w;
+                if (!expired) {
+                    // Notify all threads waiting for read done.
+                    if (AllowBthreadSuspended) {
+                        w->WaitReadDone(bg_index);
+                    } else {
+                        w->WaitReadDone();
+                    }
+               }
+                // Remove expired weak_ptr.
+                return expired;
+            }),
+            _wrappers.end());
     }
 
     const size_t ret2 = fn(_data[bg_index], std::forward<Args>(args)...);
diff --git a/src/json2pb/pb_to_json.cpp b/src/json2pb/pb_to_json.cpp
index c23ccdf7..e37cc87d 100644
--- a/src/json2pb/pb_to_json.cpp
+++ b/src/json2pb/pb_to_json.cpp
@@ -336,14 +336,14 @@ bool ProtoMessageToJson(const google::protobuf::Message& message,
 }
 
 bool ProtoMessageToJson(const google::protobuf::Message& message,
-                        google::protobuf::io::ZeroCopyOutputStream *stream,
+                        google::protobuf::io::ZeroCopyOutputStream* stream,
                         const Pb2JsonOptions& options, std::string* error) {
     json2pb::ZeroCopyStreamWriter wrapper(stream);
     return json2pb::ProtoMessageToJsonStream(message, options, wrapper, error);
 }
 
 bool ProtoMessageToJson(const google::protobuf::Message& message,
-                        google::protobuf::io::ZeroCopyOutputStream *stream,
+                        google::protobuf::io::ZeroCopyOutputStream* stream,
                         std::string* error) {
     return ProtoMessageToJson(message, stream, Pb2JsonOptions(), error);
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@brpc.apache.org
For additional commands, e-mail: dev-h...@brpc.apache.org

Reply via email to