This is an automated email from the ASF dual-hosted git repository.

wwbmmm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/brpc.git


The following commit(s) were added to refs/heads/master by this push:
     new 02fd47ca Support thread local object iteration (#2632)
02fd47ca is described below

commit 02fd47caa000f4b9abd5c03930f2b190cc285640
Author: Bright Chen <[email protected]>
AuthorDate: Mon Jun 3 16:51:11 2024 +0800

    Support thread local object iteration (#2632)
---
 src/butil/thread_key.cpp     | 81 ++++++++++++++++----------------------------
 src/butil/thread_key.h       | 28 ++++++++++++---
 test/endpoint_unittest.cpp   |  4 +--
 test/thread_key_unittest.cpp | 44 ++++++++++++++++++++++--
 4 files changed, 96 insertions(+), 61 deletions(-)

diff --git a/src/butil/thread_key.cpp b/src/butil/thread_key.cpp
index 02bcd586..3bf4bb0f 100644
--- a/src/butil/thread_key.cpp
+++ b/src/butil/thread_key.cpp
@@ -38,7 +38,7 @@ pthread_mutex_t g_thread_key_mutex = PTHREAD_MUTEX_INITIALIZER;
 static size_t g_id = 0;
 static std::deque<size_t>* g_free_ids = NULL;
 static std::vector<ThreadKeyInfo>* g_thread_keys = NULL;
-static __thread std::vector<ThreadKeyTLS>* g_tls_data = NULL;
+static __thread std::vector<ThreadKeyTLS>* thread_key_tls_data = NULL;
 
 ThreadKey& ThreadKey::operator=(ThreadKey&& other) noexcept {
     if (this == &other) {
@@ -56,58 +56,42 @@ bool ThreadKey::Valid() const {
 }
 
 static void DestroyTlsData() {
-    if (!g_tls_data) {
+    if (!thread_key_tls_data) {
         return;
     }
     std::vector<ThreadKeyInfo> dummy_keys;
     {
         BAIDU_SCOPED_LOCK(g_thread_key_mutex);
-        if (BAIDU_LIKELY(g_thread_keys)) {
-            dummy_keys.insert(dummy_keys.end(), g_thread_keys->begin(), g_thread_keys->end());
-        }
+        dummy_keys.insert(dummy_keys.end(),
+                          g_thread_keys->begin(),
+                          g_thread_keys->end());
     }
-    for (size_t i = 0; i < g_tls_data->size(); ++i) {
+    for (size_t i = 0; i < thread_key_tls_data->size(); ++i) {
         if (!KEY_UNUSED(dummy_keys[i].seq) && dummy_keys[i].dtor) {
-            dummy_keys[i].dtor((*g_tls_data)[i].data);
+            dummy_keys[i].dtor((*thread_key_tls_data)[i].data);
         }
     }
-    delete g_tls_data;
-    g_tls_data = NULL;
-}
-
-static std::deque<size_t>* GetGlobalFreeIds() {
-    if (BAIDU_UNLIKELY(!g_free_ids)) {
-        g_free_ids = new (std::nothrow) std::deque<size_t>();
-        if (BAIDU_UNLIKELY(!g_free_ids)) {
-            abort();
-        }
-    }
-
-    return g_free_ids;
+    delete thread_key_tls_data;
+    thread_key_tls_data = NULL;
 }
 
 int thread_key_create(ThreadKey& thread_key, DtorFunction dtor) {
     BAIDU_SCOPED_LOCK(g_thread_key_mutex);
-    size_t id;
-    auto free_ids = GetGlobalFreeIds();
-    if (!free_ids) {
-        return ENOMEM;
+    if (BAIDU_UNLIKELY(!g_free_ids)) {
+        g_free_ids = new std::deque<size_t>;
     }
-
-    if (!free_ids->empty()) {
-        id = free_ids->back();
-        free_ids->pop_back();
+    size_t id;
+    if (!g_free_ids->empty()) {
+        id = g_free_ids->back();
+        g_free_ids->pop_back();
     } else {
         if (g_id >= ThreadKey::InvalidID) {
             // No more available ids.
             return EAGAIN;
         }
         id = g_id++;
-        if(BAIDU_UNLIKELY(!g_thread_keys)) {
-            g_thread_keys = new (std::nothrow) std::vector<ThreadKeyInfo>;
-            if(BAIDU_UNLIKELY(!g_thread_keys)) {
-                return ENOMEM;
-            }
+        if (BAIDU_UNLIKELY(!g_thread_keys)) {
+            g_thread_keys = new std::vector<ThreadKeyInfo>;
             g_thread_keys->reserve(THREAD_KEY_RESERVE);
         }
         g_thread_keys->resize(id + 1);
@@ -136,14 +120,10 @@ int thread_key_delete(ThreadKey& thread_key) {
         return EINVAL;
     }
 
-    if (BAIDU_UNLIKELY(!GetGlobalFreeIds())) {
-        return ENOMEM;
-    }
-
     ++((*g_thread_keys)[id].seq);
     // Collect the usable key id for reuse.
     if (KEY_USABLE((*g_thread_keys)[id].seq)) {
-        GetGlobalFreeIds()->push_back(id);
+        g_free_ids->push_back(id);
     }
     thread_key.Reset();
 
@@ -156,22 +136,19 @@ int thread_setspecific(ThreadKey& thread_key, void* data) {
     }
     size_t id = thread_key._id;
     size_t seq = thread_key._seq;
-    if (BAIDU_UNLIKELY(!g_tls_data)) {
-        g_tls_data = new (std::nothrow) std::vector<ThreadKeyTLS>;
-        if (BAIDU_UNLIKELY(!g_tls_data)) {
-            return ENOMEM;
-        }
-        g_tls_data->reserve(THREAD_KEY_RESERVE);
+    if (BAIDU_UNLIKELY(!thread_key_tls_data)) {
+        thread_key_tls_data = new std::vector<ThreadKeyTLS>;
+        thread_key_tls_data->reserve(THREAD_KEY_RESERVE);
         // Register the destructor of tls_data in this thread.
         butil::thread_atexit(DestroyTlsData);
     }
 
-    if (id >= g_tls_data->size()) {
-        g_tls_data->resize(id + 1);
+    if (id >= thread_key_tls_data->size()) {
+        thread_key_tls_data->resize(id + 1);
     }
 
-    (*g_tls_data)[id].seq  = seq;
-    (*g_tls_data)[id].data = data;
+    (*thread_key_tls_data)[id].seq  = seq;
+    (*thread_key_tls_data)[id].data = data;
 
     return 0;
 }
@@ -182,13 +159,13 @@ void* thread_getspecific(ThreadKey& thread_key) {
     }
     size_t id = thread_key._id;
     size_t seq = thread_key._seq;
-    if (BAIDU_UNLIKELY(!g_tls_data ||
-                       id >= g_tls_data->size() ||
-                       (*g_tls_data)[id].seq != seq)){
+    if (BAIDU_UNLIKELY(!thread_key_tls_data ||
+                       id >= thread_key_tls_data->size() ||
+                       (*thread_key_tls_data)[id].seq != seq)){
         return NULL;
     }
 
-    return (*g_tls_data)[id].data;
+    return (*thread_key_tls_data)[id].data;
 }
 
 } // namespace butil
\ No newline at end of file
diff --git a/src/butil/thread_key.h b/src/butil/thread_key.h
index f8d8f0e4..e95fa2fa 100644
--- a/src/butil/thread_key.h
+++ b/src/butil/thread_key.h
@@ -23,6 +23,7 @@
 #include <stdlib.h>
 #include <vector>
 #include "butil/scoped_lock.h"
+#include "butil/type_traits.h"
 
 namespace butil {
 
@@ -38,7 +39,7 @@ public:
     static constexpr size_t InvalidID = std::numeric_limits<size_t>::max();
     static constexpr size_t InitSeq = 0;
 
-    constexpr ThreadKey() :_id(InvalidID), _seq(InitSeq) {}
+    constexpr ThreadKey() : _id(InvalidID), _seq(InitSeq) {}
 
     ~ThreadKey() {
         Reset();
@@ -62,7 +63,7 @@ public:
         _seq = InitSeq;
     }
 
-    private:
+private:
     size_t _id; // Key id.
     // Sequence number form g_thread_keys set in thread_key_create.
     size_t _seq;
@@ -111,6 +112,20 @@ public:
 
     T& operator*() const { return *get(); }
 
+    // Iterate through all thread local objects.
+    // Callback, which must accept Args params and return void,
+    // will be called under a thread lock.
+    template <typename Callback>
+    void for_each(Callback&& callback) {
+        BAIDU_CASSERT(
+            (is_result_void<Callback, T*>::value),
+            "Callback must accept Args params and return void");
+        BAIDU_SCOPED_LOCK(_mutex);
+        for (auto ptr : ptrs) {
+            callback(ptr);
+        }
+    }
+
     void reset(T* ptr);
 
     void reset() {
@@ -177,6 +192,9 @@ T* ThreadLocal<T>::get() {
 template <typename T>
 void ThreadLocal<T>::reset(T* ptr) {
     T* old_ptr = get();
+    if (ptr == old_ptr) {
+        return;
+    }
     if (thread_setspecific(_key, ptr) != 0) {
         return;
     }
@@ -187,9 +205,9 @@ void ThreadLocal<T>::reset(T* ptr) {
         }
         // Remove and delete old_ptr.
         if (old_ptr) {
-            auto iter = std::find(ptrs.begin(), ptrs.end(), old_ptr);
-            if (iter!=ptrs.end()) {
-                ptrs.erase(iter);
+            auto iter = std::remove(ptrs.begin(), ptrs.end(), old_ptr);
+            if (iter != ptrs.end()) {
+                ptrs.erase(iter, ptrs.end());
             }
             DefaultDtor(old_ptr);
         }
diff --git a/test/endpoint_unittest.cpp b/test/endpoint_unittest.cpp
index fcb23a7b..e0da1af1 100644
--- a/test/endpoint_unittest.cpp
+++ b/test/endpoint_unittest.cpp
@@ -486,11 +486,11 @@ TEST(EndPointTest, tcp_connect) {
     ASSERT_EQ(0, butil::hostname2endpoint(g_hostname, 80, &ep));
     {
         butil::fd_guard sockfd(butil::tcp_connect(ep, NULL));
-        ASSERT_LE(0, sockfd);
+        ASSERT_LE(0, sockfd) << "errno=" << errno;
     }
     {
         butil::fd_guard sockfd(butil::tcp_connect(ep, NULL, 1000));
-        ASSERT_LE(0, sockfd);
+        ASSERT_LE(0, sockfd) << "errno=" << errno;
     }
     {
         butil::fd_guard sockfd(butil::tcp_connect(ep, NULL, 1));
diff --git a/test/thread_key_unittest.cpp b/test/thread_key_unittest.cpp
index a4609aed..adbeb024 100644
--- a/test/thread_key_unittest.cpp
+++ b/test/thread_key_unittest.cpp
@@ -104,7 +104,7 @@ TEST(ThreadLocalTest, thread_key_seq) {
     }
 }
 
-void* THreadKeyCreateAndDeleteFunc(void* arg) {
+void* THreadKeyCreateAndDeleteFunc(void*) {
     while (!g_stopped) {
         ThreadKey key;
         EXPECT_EQ(0, butil::thread_key_create(key, NULL));
@@ -162,7 +162,7 @@ TEST(ThreadLocalTest, thread_local_multi_thread) {
         ASSERT_EQ(0, pthread_create(&threads[i], NULL, ThreadLocalFunc, &args));
     }
 
-    sleep(5);
+    sleep(2);
     g_stopped = true;
     for (const auto& thread : threads) {
         pthread_join(thread, NULL);
@@ -172,6 +172,46 @@ TEST(ThreadLocalTest, thread_local_multi_thread) {
     }
 }
 
+butil::atomic<int> g_counter(0);
+
+void* ThreadLocalForEachFunc(void* arg) {
+    auto counter = static_cast<ThreadLocal<butil::atomic<int>>*>(arg);
+    auto local_counter = counter->get();
+    EXPECT_NE(nullptr, local_counter);
+    while (!g_stopped) {
+        local_counter->fetch_add(1, butil::memory_order_relaxed);
+        g_counter.fetch_add(1, butil::memory_order_relaxed);
+        if (butil::fast_rand_less_than(100) + 1 > 80) {
+            local_counter = new butil::atomic<int>(
+                local_counter->load(butil::memory_order_relaxed));
+            counter->reset(local_counter);
+        }
+    }
+    return NULL;
+}
+
+TEST(ThreadLocalTest, thread_local_for_each) {
+    g_stopped = false;
+    ThreadLocal<butil::atomic<int>> counter(false);
+    const int thread_num = 8;
+    pthread_t threads[thread_num];
+    for (int i = 0; i < thread_num; ++i) {
+        ASSERT_EQ(0, pthread_create(
+            &threads[i], NULL, ThreadLocalForEachFunc, &counter));
+    }
+
+    sleep(2);
+    g_stopped = true;
+    for (const auto& thread : threads) {
+        pthread_join(thread, NULL);
+    }
+    int count = 0;
+    counter.for_each([&count](butil::atomic<int>* c) {
+        count += c->load(butil::memory_order_relaxed);
+    });
+    ASSERT_EQ(count, g_counter.load(butil::memory_order_relaxed));
+}
+
 struct BAIDU_CACHELINE_ALIGNMENT ThreadKeyArg {
     std::vector<ThreadKey*> thread_keys;
     bool ready_delete = false;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to