This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 770c07431 [CELEBORN-2169][CIP-14] Support ConcurrentHashMap and 
refactor reducerFileGroupInfos
770c07431 is described below

commit 770c07431b241b2fcd802c15f10e25aa6f648703
Author: HolyLow <[email protected]>
AuthorDate: Fri Oct 17 17:55:38 2025 +0800

    [CELEBORN-2169][CIP-14] Support ConcurrentHashMap and refactor 
reducerFileGroupInfos
    
    ### What changes were proposed in this pull request?
    This PR adds a ConcurrentHashMap data structure and refactors the existing 
reducerFileGroupInfos code to use it.
    
    ### Why are the changes needed?
    The ConcurrentHashMap will be widely used in WriterClient, and it greatly 
simplifies the code by providing synchronization semantics.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and E2E test.
    
    Closes #3499 from HolyLow/issue/celeborn-2169-support-concurrent-hash-map.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/ShuffleClient.cpp          |  50 +++----
 cpp/celeborn/client/ShuffleClient.h            |   9 +-
 cpp/celeborn/utils/CelebornUtils.h             |  70 ++++++++++
 cpp/celeborn/utils/tests/CMakeLists.txt        |   2 +-
 cpp/celeborn/utils/tests/CelebornUtilsTest.cpp | 186 +++++++++++++++++++++++++
 5 files changed, 282 insertions(+), 35 deletions(-)

diff --git a/cpp/celeborn/client/ShuffleClient.cpp 
b/cpp/celeborn/client/ShuffleClient.cpp
index 92279e4f3..6e29db909 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -65,20 +65,21 @@ std::unique_ptr<CelebornInputStream> 
ShuffleClientImpl::readPartition(
     int startMapIndex,
     int endMapIndex,
     bool needCompression) {
-  const auto& reducerFileGroupInfo = getReducerFileGroupInfo(shuffleId);
+  const auto reducerFileGroupInfo = getReducerFileGroupInfo(shuffleId);
+  CELEBORN_CHECK_NOT_NULL(reducerFileGroupInfo);
   std::string shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
   std::vector<std::shared_ptr<const protocol::PartitionLocation>> locations;
-  if (!reducerFileGroupInfo.fileGroups.empty() &&
-      reducerFileGroupInfo.fileGroups.count(partitionId)) {
+  if (!reducerFileGroupInfo->fileGroups.empty() &&
+      reducerFileGroupInfo->fileGroups.count(partitionId)) {
     locations = std::move(utils::toVector(
-        reducerFileGroupInfo.fileGroups.find(partitionId)->second));
+        reducerFileGroupInfo->fileGroups.find(partitionId)->second));
   }
   return std::make_unique<CelebornInputStream>(
       shuffleKey,
       conf_,
       clientFactory_,
       std::move(locations),
-      reducerFileGroupInfo.attempts,
+      reducerFileGroupInfo->attempts,
       attemptNumber,
       startMapIndex,
       endMapIndex,
@@ -98,13 +99,10 @@ void ShuffleClientImpl::updateReducerFileGroup(int 
shuffleId) {
   switch (reducerFileGroupInfo->status) {
     case protocol::SUCCESS: {
       VLOG(1) << "success to get reducerFileGroupInfo, shuffleId " << 
shuffleId;
-      std::lock_guard<std::mutex> lock(mutex_);
-      if (reducerFileGroupInfos_.count(shuffleId) > 0) {
-        VLOG(1) << "reducerFileGroupInfo for shuffleId" << shuffleId
-                << " already exists, ignored";
-        return;
-      }
-      reducerFileGroupInfos_[shuffleId] = std::move(reducerFileGroupInfo);
+      reducerFileGroupInfos_.set(
+          shuffleId,
+          std::shared_ptr<protocol::GetReducerFileGroupResponse>(
+              reducerFileGroupInfo.release()));
       return;
     }
     case protocol::SHUFFLE_NOT_REGISTERED: {
@@ -112,13 +110,10 @@ void ShuffleClientImpl::updateReducerFileGroup(int 
shuffleId) {
       // shuffle.
       LOG(WARNING) << "shuffleId " << shuffleId
                    << " is not registered when get reducerFileGroupInfo";
-      std::lock_guard<std::mutex> lock(mutex_);
-      if (reducerFileGroupInfos_.count(shuffleId) > 0) {
-        VLOG(1) << "reducerFileGroupInfo for shuffleId" << shuffleId
-                << " already exists, ignored";
-        return;
-      }
-      reducerFileGroupInfos_[shuffleId] = std::move(reducerFileGroupInfo);
+      reducerFileGroupInfos_.set(
+          shuffleId,
+          std::shared_ptr<protocol::GetReducerFileGroupResponse>(
+              reducerFileGroupInfo.release()));
       return;
     }
     case protocol::STAGE_END_TIME_OUT:
@@ -140,21 +135,16 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
   return true;
 }
 
-protocol::GetReducerFileGroupResponse&
+std::shared_ptr<protocol::GetReducerFileGroupResponse>
 ShuffleClientImpl::getReducerFileGroupInfo(int shuffleId) {
-  {
-    std::lock_guard<std::mutex> lock(mutex_);
-    auto iter = reducerFileGroupInfos_.find(shuffleId);
-    if (iter != reducerFileGroupInfos_.end()) {
-      return *iter->second;
-    }
+  auto reducerFileGroupInfoOptional = reducerFileGroupInfos_.get(shuffleId);
+  if (reducerFileGroupInfoOptional.has_value()) {
+    return reducerFileGroupInfoOptional.value();
   }
 
   updateReducerFileGroup(shuffleId);
-  {
-    std::lock_guard<std::mutex> lock(mutex_);
-    return *reducerFileGroupInfos_[shuffleId];
-  }
+
+  return getReducerFileGroupInfo(shuffleId);
 }
 } // namespace client
 } // namespace celeborn
diff --git a/cpp/celeborn/client/ShuffleClient.h 
b/cpp/celeborn/client/ShuffleClient.h
index 284c7ade9..b56c60cf8 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -85,16 +85,17 @@ class ShuffleClientImpl : public ShuffleClient {
   void shutdown() override {}
 
  private:
-  protocol::GetReducerFileGroupResponse& getReducerFileGroupInfo(int 
shuffleId);
+  std::shared_ptr<protocol::GetReducerFileGroupResponse>
+  getReducerFileGroupInfo(int shuffleId);
 
   const std::string appUniqueId_;
   std::shared_ptr<const conf::CelebornConf> conf_;
   std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
   std::shared_ptr<network::TransportClientFactory> clientFactory_;
   std::mutex mutex_;
-  std::unordered_map<
-      long,
-      std::unique_ptr<protocol::GetReducerFileGroupResponse>>
+  utils::ConcurrentHashMap<
+      int,
+      std::shared_ptr<protocol::GetReducerFileGroupResponse>>
       reducerFileGroupInfos_;
 };
 } // namespace client
diff --git a/cpp/celeborn/utils/CelebornUtils.h 
b/cpp/celeborn/utils/CelebornUtils.h
index 03f8e1514..df42a09e8 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <folly/Synchronized.h>
 #include <google/protobuf/io/coded_stream.h>
 #include <charconv>
 #include <chrono>
@@ -113,5 +114,74 @@ std::unique_ptr<T> parseProto(const uint8_t* bytes, int 
len) {
   return pbObj;
 }
 
+template <typename TKey, typename TValue, typename THasher = std::hash<TKey>>
+class ConcurrentHashMap {
+ public:
+  std::optional<TValue> get(const TKey& key) {
+    // Explicitly declaring the return type helps type deduction.
+    return synchronizedMap_.withLock([&](auto& map) -> std::optional<TValue> {
+      auto iter = map.find(key);
+      if (iter != map.end()) {
+        return iter->second;
+      }
+      return std::nullopt;
+    });
+  }
+
+  bool containsKey(const TKey& key) {
+    return synchronizedMap_.withLock([&](auto& map) {
+      auto iter = map.find(key);
+      if (iter != map.end()) {
+        return true;
+      }
+      return false;
+    });
+  }
+
+  TValue computeIfAbsent(const TKey& key, std::function<TValue()> compute) {
+    return synchronizedMap_.withLock([&](auto& map) {
+      auto iter = map.find(key);
+      if (iter != map.end()) {
+        return iter->second;
+      }
+      map[key] = compute();
+      return map[key];
+    });
+  }
+
+  void set(const TKey& key, TValue&& value) {
+    synchronizedMap_.withLock([&](auto& map) { map[key] = std::move(value); });
+  }
+
+  void set(const TKey& key, const TValue& value) {
+    synchronizedMap_.withLock([&](auto& map) { map[key] = value; });
+  }
+
+  size_t size() const {
+    return synchronizedMap_.lock()->size();
+  }
+
+  std::optional<TValue> erase(const TKey& key) {
+    // Explicitly declaring the return type helps type deduction.
+    return synchronizedMap_.withLock([&](auto& map) -> std::optional<TValue> {
+      auto iter = map.find(key);
+      if (iter != map.end()) {
+        auto result = std::move(iter->second);
+        map.erase(key);
+        return std::move(result);
+      }
+      return std::nullopt;
+    });
+  }
+
+  void clear() {
+    synchronizedMap_.lock()->clear();
+  }
+
+ private:
+  folly::Synchronized<std::unordered_map<TKey, TValue, THasher>, std::mutex>
+      synchronizedMap_;
+};
+
 } // namespace utils
 } // namespace celeborn
diff --git a/cpp/celeborn/utils/tests/CMakeLists.txt 
b/cpp/celeborn/utils/tests/CMakeLists.txt
index a820b4ac1..b54b70c43 100644
--- a/cpp/celeborn/utils/tests/CMakeLists.txt
+++ b/cpp/celeborn/utils/tests/CMakeLists.txt
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-add_executable(celeborn_utils_test ExceptionTest.cpp)
+add_executable(celeborn_utils_test ExceptionTest.cpp CelebornUtilsTest.cpp)
 
 add_test(NAME celeborn_utils_test COMMAND celeborn_utils_test)
 
diff --git a/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp 
b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
new file mode 100644
index 000000000..53f29cb62
--- /dev/null
+++ b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "celeborn/utils/CelebornUtils.h"
+
+using namespace celeborn::utils;
+
+class CelebornUtilsTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    map_ = std::make_unique<ConcurrentHashMap<std::string, int>>();
+  }
+
+  std::unique_ptr<ConcurrentHashMap<std::string, int>> map_;
+};
+
+TEST_F(CelebornUtilsTest, mapBasicInsertAndRetrieve) {
+  map_->set("apple", 10);
+  auto result = map_->get("apple");
+  ASSERT_TRUE(result.has_value());
+  EXPECT_EQ(10, *result);
+}
+
+TEST_F(CelebornUtilsTest, mapUpdateExistingKey) {
+  map_->set("apple", 10);
+  map_->set("apple", 20);
+  auto result = map_->get("apple");
+  ASSERT_TRUE(result.has_value());
+  EXPECT_EQ(20, *result);
+}
+
+TEST_F(CelebornUtilsTest, mapComputeIfAbsent) {
+  map_->set("apple", 10);
+  map_->computeIfAbsent("apple", []() { return 20; });
+  auto result = map_->get("apple");
+  ASSERT_TRUE(result.has_value());
+  EXPECT_EQ(10, *result);
+
+  map_->computeIfAbsent("banana", []() { return 30; });
+  map_->computeIfAbsent("banana", []() { return 40; });
+  result = map_->get("banana");
+  ASSERT_TRUE(result.has_value());
+  EXPECT_EQ(30, *result);
+}
+
+TEST_F(CelebornUtilsTest, mapRemoveKey) {
+  map_->set("banana", 30);
+  map_->erase("banana");
+  auto result = map_->get("banana");
+  EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(CelebornUtilsTest, mapNonExistentKey) {
+  auto result = map_->get("mango");
+  EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(CelebornUtilsTest, mapConcurrentInserts) {
+  constexpr int NUM_THREADS = 8;
+  constexpr int ITEMS_PER_THREAD = 100;
+  std::vector<std::thread> threads;
+
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    threads.emplace_back([this, i] {
+      for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+        std::string key =
+            "thread" + std::to_string(i) + "-" + std::to_string(j);
+        map_->set(key, j);
+      }
+    });
+  }
+
+  for (auto& t : threads) {
+    t.join();
+  }
+
+  // Verify all items were inserted
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+      std::string key = "thread" + std::to_string(i) + "-" + std::to_string(j);
+      auto result = map_->get(key);
+      ASSERT_TRUE(result.has_value()) << "Missing key: " << key;
+      EXPECT_EQ(j, *result);
+    }
+  }
+}
+
+TEST_F(CelebornUtilsTest, mapConcurrentUpdates) {
+  constexpr int NUM_THREADS = 8;
+  std::vector<std::thread> threads;
+  std::atomic<bool> start{false};
+  // Initial value
+  map_->set("contended", 0);
+
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    threads.emplace_back([this, &start, i] {
+      while (!start) { /* spin */
+      } // Wait for start signal
+
+      for (int j = 0; j < 100; ++j) {
+        map_->set("contended", i * 100 + j);
+      }
+    });
+  }
+
+  start = true;
+
+  for (auto& t : threads) {
+    t.join();
+  }
+
+  // Verify the final value is from the last writer
+  auto result = map_->get("contended");
+  ASSERT_TRUE(result.has_value());
+  // The exact value depends on thread scheduling, but it should be
+  // from one of the threads (between 0*100+99 and 7*100+99)
+  EXPECT_GE(*result, 99);
+  EXPECT_LE(*result, 799);
+}
+
+TEST_F(CelebornUtilsTest, mapConcurrentReadWrite) {
+  constexpr int NUM_WRITERS = 4;
+  constexpr int NUM_READERS = 4;
+  std::atomic<bool> running{true};
+  std::vector<std::thread> writers;
+  std::vector<std::thread> readers;
+
+  // Writers constantly update values
+  for (int i = 0; i < NUM_WRITERS; ++i) {
+    writers.emplace_back([this, i, &running] {
+      while (running) {
+        map_->set("key" + std::to_string(i), i);
+      }
+    });
+  }
+
+  // Readers constantly read values
+  for (int i = 0; i < NUM_READERS; ++i) {
+    readers.emplace_back([this, &running] {
+      while (running) {
+        for (int j = 0; j < NUM_WRITERS; ++j) {
+          auto result = map_->get("key" + std::to_string(j));
+        }
+      }
+    });
+  }
+
+  // Let them run for 500ms
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+  running = false;
+
+  for (auto& t : writers) {
+    t.join();
+  }
+
+  for (auto& t : readers) {
+    t.join();
+  }
+
+  // Verify final values are from writers
+  for (int i = 0; i < NUM_WRITERS; ++i) {
+    auto result = map_->get("key" + std::to_string(i));
+    ASSERT_TRUE(result.has_value());
+    EXPECT_EQ(i, *result);
+  }
+}

Reply via email to