This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 770c07431 [CELEBORN-2169][CIP-14] Support ConcurrentHashMap and
refactor reducerFileGroupInfos
770c07431 is described below
commit 770c07431b241b2fcd802c15f10e25aa6f648703
Author: HolyLow <[email protected]>
AuthorDate: Fri Oct 17 17:55:38 2025 +0800
[CELEBORN-2169][CIP-14] Support ConcurrentHashMap and refactor
reducerFileGroupInfos
### What changes were proposed in this pull request?
This PR adds a ConcurrentHashMap data structure and refactors the existing
reducerFileGroupInfos code to use it.
### Why are the changes needed?
The ConcurrentHashMap will be widely used in WriterClient, and it would
greatly simplify the code by providing synchronization semantics.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and E2E test.
Closes #3499 from HolyLow/issue/celeborn-2169-support-concurrent-hash-map.
Authored-by: HolyLow <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/ShuffleClient.cpp | 50 +++----
cpp/celeborn/client/ShuffleClient.h | 9 +-
cpp/celeborn/utils/CelebornUtils.h | 70 ++++++++++
cpp/celeborn/utils/tests/CMakeLists.txt | 2 +-
cpp/celeborn/utils/tests/CelebornUtilsTest.cpp | 186 +++++++++++++++++++++++++
5 files changed, 282 insertions(+), 35 deletions(-)
diff --git a/cpp/celeborn/client/ShuffleClient.cpp
b/cpp/celeborn/client/ShuffleClient.cpp
index 92279e4f3..6e29db909 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -65,20 +65,21 @@ std::unique_ptr<CelebornInputStream>
ShuffleClientImpl::readPartition(
int startMapIndex,
int endMapIndex,
bool needCompression) {
- const auto& reducerFileGroupInfo = getReducerFileGroupInfo(shuffleId);
+ const auto reducerFileGroupInfo = getReducerFileGroupInfo(shuffleId);
+ CELEBORN_CHECK_NOT_NULL(reducerFileGroupInfo);
std::string shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
std::vector<std::shared_ptr<const protocol::PartitionLocation>> locations;
- if (!reducerFileGroupInfo.fileGroups.empty() &&
- reducerFileGroupInfo.fileGroups.count(partitionId)) {
+ if (!reducerFileGroupInfo->fileGroups.empty() &&
+ reducerFileGroupInfo->fileGroups.count(partitionId)) {
locations = std::move(utils::toVector(
- reducerFileGroupInfo.fileGroups.find(partitionId)->second));
+ reducerFileGroupInfo->fileGroups.find(partitionId)->second));
}
return std::make_unique<CelebornInputStream>(
shuffleKey,
conf_,
clientFactory_,
std::move(locations),
- reducerFileGroupInfo.attempts,
+ reducerFileGroupInfo->attempts,
attemptNumber,
startMapIndex,
endMapIndex,
@@ -98,13 +99,10 @@ void ShuffleClientImpl::updateReducerFileGroup(int
shuffleId) {
switch (reducerFileGroupInfo->status) {
case protocol::SUCCESS: {
VLOG(1) << "success to get reducerFileGroupInfo, shuffleId " <<
shuffleId;
- std::lock_guard<std::mutex> lock(mutex_);
- if (reducerFileGroupInfos_.count(shuffleId) > 0) {
- VLOG(1) << "reducerFileGroupInfo for shuffleId" << shuffleId
- << " already exists, ignored";
- return;
- }
- reducerFileGroupInfos_[shuffleId] = std::move(reducerFileGroupInfo);
+ reducerFileGroupInfos_.set(
+ shuffleId,
+ std::shared_ptr<protocol::GetReducerFileGroupResponse>(
+ reducerFileGroupInfo.release()));
return;
}
case protocol::SHUFFLE_NOT_REGISTERED: {
@@ -112,13 +110,10 @@ void ShuffleClientImpl::updateReducerFileGroup(int
shuffleId) {
// shuffle.
LOG(WARNING) << "shuffleId " << shuffleId
<< " is not registered when get reducerFileGroupInfo";
- std::lock_guard<std::mutex> lock(mutex_);
- if (reducerFileGroupInfos_.count(shuffleId) > 0) {
- VLOG(1) << "reducerFileGroupInfo for shuffleId" << shuffleId
- << " already exists, ignored";
- return;
- }
- reducerFileGroupInfos_[shuffleId] = std::move(reducerFileGroupInfo);
+ reducerFileGroupInfos_.set(
+ shuffleId,
+ std::shared_ptr<protocol::GetReducerFileGroupResponse>(
+ reducerFileGroupInfo.release()));
return;
}
case protocol::STAGE_END_TIME_OUT:
@@ -140,21 +135,16 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
return true;
}
-protocol::GetReducerFileGroupResponse&
+std::shared_ptr<protocol::GetReducerFileGroupResponse>
ShuffleClientImpl::getReducerFileGroupInfo(int shuffleId) {
- {
- std::lock_guard<std::mutex> lock(mutex_);
- auto iter = reducerFileGroupInfos_.find(shuffleId);
- if (iter != reducerFileGroupInfos_.end()) {
- return *iter->second;
- }
+ auto reducerFileGroupInfoOptional = reducerFileGroupInfos_.get(shuffleId);
+ if (reducerFileGroupInfoOptional.has_value()) {
+ return reducerFileGroupInfoOptional.value();
}
updateReducerFileGroup(shuffleId);
- {
- std::lock_guard<std::mutex> lock(mutex_);
- return *reducerFileGroupInfos_[shuffleId];
- }
+
+ return getReducerFileGroupInfo(shuffleId);
}
} // namespace client
} // namespace celeborn
diff --git a/cpp/celeborn/client/ShuffleClient.h
b/cpp/celeborn/client/ShuffleClient.h
index 284c7ade9..b56c60cf8 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -85,16 +85,17 @@ class ShuffleClientImpl : public ShuffleClient {
void shutdown() override {}
private:
- protocol::GetReducerFileGroupResponse& getReducerFileGroupInfo(int
shuffleId);
+ std::shared_ptr<protocol::GetReducerFileGroupResponse>
+ getReducerFileGroupInfo(int shuffleId);
const std::string appUniqueId_;
std::shared_ptr<const conf::CelebornConf> conf_;
std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
std::shared_ptr<network::TransportClientFactory> clientFactory_;
std::mutex mutex_;
- std::unordered_map<
- long,
- std::unique_ptr<protocol::GetReducerFileGroupResponse>>
+ utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<protocol::GetReducerFileGroupResponse>>
reducerFileGroupInfos_;
};
} // namespace client
diff --git a/cpp/celeborn/utils/CelebornUtils.h
b/cpp/celeborn/utils/CelebornUtils.h
index 03f8e1514..df42a09e8 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -17,6 +17,7 @@
#pragma once
+#include <folly/Synchronized.h>
#include <google/protobuf/io/coded_stream.h>
#include <charconv>
#include <chrono>
@@ -113,5 +114,74 @@ std::unique_ptr<T> parseProto(const uint8_t* bytes, int
len) {
return pbObj;
}
+template <typename TKey, typename TValue, typename THasher = std::hash<TKey>>
+class ConcurrentHashMap {
+ public:
+ std::optional<TValue> get(const TKey& key) {
+ // Explicitly declaring the return type helps type deduction.
+ return synchronizedMap_.withLock([&](auto& map) -> std::optional<TValue> {
+ auto iter = map.find(key);
+ if (iter != map.end()) {
+ return iter->second;
+ }
+ return std::nullopt;
+ });
+ }
+
+ bool containsKey(const TKey& key) {
+ return synchronizedMap_.withLock([&](auto& map) {
+ auto iter = map.find(key);
+ if (iter != map.end()) {
+ return true;
+ }
+ return false;
+ });
+ }
+
+ TValue computeIfAbsent(const TKey& key, std::function<TValue()> compute) {
+ return synchronizedMap_.withLock([&](auto& map) {
+ auto iter = map.find(key);
+ if (iter != map.end()) {
+ return iter->second;
+ }
+ map[key] = compute();
+ return map[key];
+ });
+ }
+
+ void set(const TKey& key, TValue&& value) {
+ synchronizedMap_.withLock([&](auto& map) { map[key] = std::move(value); });
+ }
+
+ void set(const TKey& key, const TValue& value) {
+ synchronizedMap_.withLock([&](auto& map) { map[key] = value; });
+ }
+
+ size_t size() const {
+ return synchronizedMap_.lock()->size();
+ }
+
+ std::optional<TValue> erase(const TKey& key) {
+ // Explicitly declaring the return type helps type deduction.
+ return synchronizedMap_.withLock([&](auto& map) -> std::optional<TValue> {
+ auto iter = map.find(key);
+ if (iter != map.end()) {
+ auto result = std::move(iter->second);
+ map.erase(key);
+ return std::move(result);
+ }
+ return std::nullopt;
+ });
+ }
+
+ void clear() {
+ synchronizedMap_.lock()->clear();
+ }
+
+ private:
+ folly::Synchronized<std::unordered_map<TKey, TValue, THasher>, std::mutex>
+ synchronizedMap_;
+};
+
} // namespace utils
} // namespace celeborn
diff --git a/cpp/celeborn/utils/tests/CMakeLists.txt
b/cpp/celeborn/utils/tests/CMakeLists.txt
index a820b4ac1..b54b70c43 100644
--- a/cpp/celeborn/utils/tests/CMakeLists.txt
+++ b/cpp/celeborn/utils/tests/CMakeLists.txt
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-add_executable(celeborn_utils_test ExceptionTest.cpp)
+add_executable(celeborn_utils_test ExceptionTest.cpp CelebornUtilsTest.cpp)
add_test(NAME celeborn_utils_test COMMAND celeborn_utils_test)
diff --git a/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
new file mode 100644
index 000000000..53f29cb62
--- /dev/null
+++ b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "celeborn/utils/CelebornUtils.h"
+
+using namespace celeborn::utils;
+
+class CelebornUtilsTest : public testing::Test {
+ protected:
+ void SetUp() override {
+ map_ = std::make_unique<ConcurrentHashMap<std::string, int>>();
+ }
+
+ std::unique_ptr<ConcurrentHashMap<std::string, int>> map_;
+};
+
+TEST_F(CelebornUtilsTest, mapBasicInsertAndRetrieve) {
+ map_->set("apple", 10);
+ auto result = map_->get("apple");
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(10, *result);
+}
+
+TEST_F(CelebornUtilsTest, mapUpdateExistingKey) {
+ map_->set("apple", 10);
+ map_->set("apple", 20);
+ auto result = map_->get("apple");
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(20, *result);
+}
+
+TEST_F(CelebornUtilsTest, mapComputeIfAbsent) {
+ map_->set("apple", 10);
+ map_->computeIfAbsent("apple", []() { return 20; });
+ auto result = map_->get("apple");
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(10, *result);
+
+ map_->computeIfAbsent("banana", []() { return 30; });
+ map_->computeIfAbsent("banana", []() { return 40; });
+ result = map_->get("banana");
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(30, *result);
+}
+
+TEST_F(CelebornUtilsTest, mapRemoveKey) {
+ map_->set("banana", 30);
+ map_->erase("banana");
+ auto result = map_->get("banana");
+ EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(CelebornUtilsTest, mapNonExistentKey) {
+ auto result = map_->get("mango");
+ EXPECT_FALSE(result.has_value());
+}
+
+TEST_F(CelebornUtilsTest, mapConcurrentInserts) {
+ constexpr int NUM_THREADS = 8;
+ constexpr int ITEMS_PER_THREAD = 100;
+ std::vector<std::thread> threads;
+
+ for (int i = 0; i < NUM_THREADS; ++i) {
+ threads.emplace_back([this, i] {
+ for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+ std::string key =
+ "thread" + std::to_string(i) + "-" + std::to_string(j);
+ map_->set(key, j);
+ }
+ });
+ }
+
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Verify all items were inserted
+ for (int i = 0; i < NUM_THREADS; ++i) {
+ for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+ std::string key = "thread" + std::to_string(i) + "-" + std::to_string(j);
+ auto result = map_->get(key);
+ ASSERT_TRUE(result.has_value()) << "Missing key: " << key;
+ EXPECT_EQ(j, *result);
+ }
+ }
+}
+
+TEST_F(CelebornUtilsTest, mapConcurrentUpdates) {
+ constexpr int NUM_THREADS = 8;
+ std::vector<std::thread> threads;
+ std::atomic<bool> start{false};
+ // Initial value
+ map_->set("contended", 0);
+
+ for (int i = 0; i < NUM_THREADS; ++i) {
+ threads.emplace_back([this, &start, i] {
+ while (!start) { /* spin */
+ } // Wait for start signal
+
+ for (int j = 0; j < 100; ++j) {
+ map_->set("contended", i * 100 + j);
+ }
+ });
+ }
+
+ start = true;
+
+ for (auto& t : threads) {
+ t.join();
+ }
+
+ // Verify the final value is from the last writer
+ auto result = map_->get("contended");
+ ASSERT_TRUE(result.has_value());
+ // The exact value depends on thread scheduling, but it should be
+ // from one of the threads (between 0*100+99 and 7*100+99)
+ EXPECT_GE(*result, 99);
+ EXPECT_LE(*result, 799);
+}
+
+TEST_F(CelebornUtilsTest, mapConcurrentReadWrite) {
+ constexpr int NUM_WRITERS = 4;
+ constexpr int NUM_READERS = 4;
+ std::atomic<bool> running{true};
+ std::vector<std::thread> writers;
+ std::vector<std::thread> readers;
+
+ // Writers constantly update values
+ for (int i = 0; i < NUM_WRITERS; ++i) {
+ writers.emplace_back([this, i, &running] {
+ while (running) {
+ map_->set("key" + std::to_string(i), i);
+ }
+ });
+ }
+
+ // Readers constantly read values
+ for (int i = 0; i < NUM_READERS; ++i) {
+ readers.emplace_back([this, &running] {
+ while (running) {
+ for (int j = 0; j < NUM_WRITERS; ++j) {
+ auto result = map_->get("key" + std::to_string(j));
+ }
+ }
+ });
+ }
+
+ // Let them run for 500ms
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+ running = false;
+
+ for (auto& t : writers) {
+ t.join();
+ }
+
+ for (auto& t : readers) {
+ t.join();
+ }
+
+ // Verify final values are from writers
+ for (int i = 0; i < NUM_WRITERS; ++i) {
+ auto result = map_->get("key" + std::to_string(i));
+ ASSERT_TRUE(result.has_value());
+ EXPECT_EQ(i, *result);
+ }
+}