This is an automated email from the ASF dual-hosted git repository. penghui pushed a commit to branch branch-2.10 in repository https://gitbox.apache.org/repos/asf/pulsar.git
commit e0d2f2072447892615f08f30aaf76ed1210c8a4a Author: Yunze Xu <[email protected]> AuthorDate: Thu Mar 10 16:03:02 2022 +0800 [C++] Fix thread safety issue for multi topic consumer (#14380) * [C++] Fix thread safety issue for multi topic consumer **Motivation** In the C++ client, if a consumer subscribes to multiple topics, a `MultiTopicsConsumerImpl` object, which manages a map of `ConsumerImpl`s (the `consumers_` field), will be created. However, `consumers_` can be accessed by multiple threads, and no mutex is held to make that access thread safe. **Modifications** - Add a `SynchronizedHashMap` class, which implements thread-safe traverse, remove, find, and clear operations. Since the callbacks passed to the `forEach` methods may call other methods of the map, use a recursive mutex instead of the default mutex. - Add a related test `SynchronizedHashMapTest` to test the methods and the thread safety of `SynchronizedHashMap`. - Use `SynchronizedHashMap` as the type of `MultiTopicsConsumerImpl::consumers_`. 
* Add findFirstValueIf method * Remove unnecessary return value of forEach * Fix incorrect calls of forEachValue * Add missed header (cherry picked from commit f94eba942b9fb3d2c25b6f7a9e2c0885a194efa0) --- pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc | 166 +++++++++------------ pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h | 5 +- pulsar-client-cpp/lib/SynchronizedHashMap.h | 127 ++++++++++++++++ pulsar-client-cpp/tests/ConsumerTest.cc | 13 +- pulsar-client-cpp/tests/SynchronizedHashMapTest.cc | 125 ++++++++++++++++ 5 files changed, 335 insertions(+), 101 deletions(-) diff --git a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc index 4e31e64..0ae86d5 100644 --- a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc +++ b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc @@ -171,7 +171,7 @@ void MultiTopicsConsumerImpl::subscribeTopicPartitions(const Result result, consumer->getConsumerCreatedFuture().addListener(std::bind( &MultiTopicsConsumerImpl::handleSingleConsumerCreated, shared_from_this(), std::placeholders::_1, std::placeholders::_2, partitionsNeedCreate, topicSubResultPromise)); - consumers_.insert(std::make_pair(topicName->toString(), consumer)); + consumers_.emplace(topicName->toString(), consumer); LOG_DEBUG("Creating Consumer for - " << topicName << " - " << consumerStr_); consumer->start(); @@ -184,7 +184,7 @@ void MultiTopicsConsumerImpl::subscribeTopicPartitions(const Result result, &MultiTopicsConsumerImpl::handleSingleConsumerCreated, shared_from_this(), std::placeholders::_1, std::placeholders::_2, partitionsNeedCreate, topicSubResultPromise)); consumer->setPartitionIndex(i); - consumers_.insert(std::make_pair(topicPartitionName, consumer)); + consumers_.emplace(topicPartitionName, consumer); LOG_DEBUG("Creating Consumer for - " << topicPartitionName << " - " << consumerStr_); consumer->start(); } @@ -232,20 +232,19 @@ void MultiTopicsConsumerImpl::unsubscribeAsync(ResultCallback callback) { 
state_ = Closing; lock.unlock(); - if (consumers_.empty()) { + std::shared_ptr<std::atomic<int>> consumerUnsubed = std::make_shared<std::atomic<int>>(0); + auto self = shared_from_this(); + int numConsumers = 0; + consumers_.forEachValue( + [&numConsumers, &consumerUnsubed, &self, callback](const ConsumerImplPtr& consumer) { + numConsumers++; + consumer->unsubscribeAsync([self, consumerUnsubed, callback](Result result) { + self->handleUnsubscribedAsync(result, consumerUnsubed, callback); + }); + }); + if (numConsumers == 0) { // No need to unsubscribe, since the list matching the regex was empty callback(ResultOk); - return; - } - - std::shared_ptr<std::atomic<int>> consumerUnsubed = std::make_shared<std::atomic<int>>(0); - - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - (consumer->second) - ->unsubscribeAsync(std::bind(&MultiTopicsConsumerImpl::handleUnsubscribedAsync, - shared_from_this(), std::placeholders::_1, consumerUnsubed, - callback)); } } @@ -299,17 +298,17 @@ void MultiTopicsConsumerImpl::unsubscribeOneTopicAsync(const std::string& topic, for (int i = 0; i < numberPartitions; i++) { std::string topicPartitionName = topicName->getTopicPartitionName(i); - std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName); - - if (consumers_.end() == iterator) { + auto optConsumer = consumers_.find(topicPartitionName); + if (optConsumer.is_empty()) { LOG_ERROR("TopicsConsumer not subscribed on topicPartitionName: " << topicPartitionName); callback(ResultUnknownError); + continue; } - (iterator->second) - ->unsubscribeAsync(std::bind(&MultiTopicsConsumerImpl::handleOneTopicUnsubscribedAsync, - shared_from_this(), std::placeholders::_1, consumerUnsubed, - numberPartitions, topicName, topicPartitionName, callback)); + optConsumer.value()->unsubscribeAsync( + std::bind(&MultiTopicsConsumerImpl::handleOneTopicUnsubscribedAsync, shared_from_this(), + 
std::placeholders::_1, consumerUnsubed, numberPartitions, topicName, topicPartitionName, + callback)); } } @@ -326,10 +325,9 @@ void MultiTopicsConsumerImpl::handleOneTopicUnsubscribedAsync( LOG_DEBUG("Successfully Unsubscribed one Consumer. topicPartitionName - " << topicPartitionName); - std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName); - if (consumers_.end() != iterator) { - iterator->second->pauseMessageListener(); - consumers_.erase(iterator); + auto optConsumer = consumers_.remove(topicPartitionName); + if (optConsumer.is_present()) { + optConsumer.value()->pauseMessageListener(); } if (consumerUnsubed->load() == numberPartitions) { @@ -363,7 +361,16 @@ void MultiTopicsConsumerImpl::closeAsync(ResultCallback callback) { setState(Closing); - if (consumers_.empty()) { + auto self = shared_from_this(); + int numConsumers = 0; + consumers_.forEach( + [&numConsumers, &self, callback](const std::string& name, const ConsumerImplPtr& consumer) { + numConsumers++; + consumer->closeAsync([self, name, callback](Result result) { + self->handleSingleConsumerClose(result, name, callback); + }); + }); + if (numConsumers == 0) { LOG_DEBUG("TopicsConsumer have no consumers to close " << " topic" << topic_ << " subscription - " << subscriptionName_); setState(Closed); @@ -373,27 +380,13 @@ void MultiTopicsConsumerImpl::closeAsync(ResultCallback callback) { return; } - // close successfully subscribed consumers - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - std::string topicPartitionName = consumer->first; - ConsumerImplPtr consumerPtr = consumer->second; - - consumerPtr->closeAsync(std::bind(&MultiTopicsConsumerImpl::handleSingleConsumerClose, - shared_from_this(), std::placeholders::_1, topicPartitionName, - callback)); - } - // fail pending recieve failPendingReceiveCallback(); } -void MultiTopicsConsumerImpl::handleSingleConsumerClose(Result result, std::string& 
topicPartitionName, +void MultiTopicsConsumerImpl::handleSingleConsumerClose(Result result, std::string topicPartitionName, CloseCallback callback) { - std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName); - if (consumers_.end() != iterator) { - consumers_.erase(iterator); - } + consumers_.remove(topicPartitionName); LOG_DEBUG("Closing the consumer for partition - " << topicPartitionName << " numberTopicPartitions_ - " << numberTopicPartitions_->load()); @@ -543,15 +536,14 @@ void MultiTopicsConsumerImpl::acknowledgeAsync(const MessageId& msgId, ResultCal } const std::string& topicPartitionName = msgId.getTopicName(); - std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName); + auto optConsumer = consumers_.find(topicPartitionName); - if (consumers_.end() != iterator) { + if (optConsumer.is_present()) { unAckedMessageTrackerPtr_->remove(msgId); - iterator->second->acknowledgeAsync(msgId, callback); + optConsumer.value()->acknowledgeAsync(msgId, callback); } else { LOG_ERROR("Message of topic: " << topicPartitionName << " not in unAckedMessageTracker"); callback(ResultUnknownError); - return; } } @@ -560,11 +552,11 @@ void MultiTopicsConsumerImpl::acknowledgeCumulativeAsync(const MessageId& msgId, } void MultiTopicsConsumerImpl::negativeAcknowledge(const MessageId& msgId) { - auto iterator = consumers_.find(msgId.getTopicName()); + auto optConsumer = consumers_.find(msgId.getTopicName()); - if (consumers_.end() != iterator) { + if (optConsumer.is_present()) { unAckedMessageTrackerPtr_->remove(msgId); - iterator->second->negativeAcknowledge(msgId); + optConsumer.value()->negativeAcknowledge(msgId); } } @@ -605,22 +597,18 @@ bool MultiTopicsConsumerImpl::isOpen() { } void MultiTopicsConsumerImpl::receiveMessages() { - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - ConsumerImplPtr consumerPtr = consumer->second; - 
consumerPtr->sendFlowPermitsToBroker(consumerPtr->getCnx().lock(), conf_.getReceiverQueueSize()); - LOG_DEBUG("Sending FLOW command for consumer - " << consumerPtr->getConsumerId()); - } + const auto receiverQueueSize = conf_.getReceiverQueueSize(); + consumers_.forEachValue([receiverQueueSize](const ConsumerImplPtr& consumer) { + consumer->sendFlowPermitsToBroker(consumer->getCnx().lock(), receiverQueueSize); + LOG_DEBUG("Sending FLOW command for consumer - " << consumer->getConsumerId()); + }); } Result MultiTopicsConsumerImpl::pauseMessageListener() { if (!messageListener_) { return ResultInvalidConfiguration; } - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - (consumer->second)->pauseMessageListener(); - } + consumers_.forEachValue([](const ConsumerImplPtr& consumer) { consumer->pauseMessageListener(); }); return ResultOk; } @@ -628,19 +616,14 @@ Result MultiTopicsConsumerImpl::resumeMessageListener() { if (!messageListener_) { return ResultInvalidConfiguration; } - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - (consumer->second)->resumeMessageListener(); - } + consumers_.forEachValue([](const ConsumerImplPtr& consumer) { consumer->resumeMessageListener(); }); return ResultOk; } void MultiTopicsConsumerImpl::redeliverUnacknowledgedMessages() { LOG_DEBUG("Sending RedeliverUnacknowledgedMessages command for partitioned consumer."); - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - (consumer->second)->redeliverUnacknowledgedMessages(); - } + consumers_.forEachValue( + [](const ConsumerImplPtr& consumer) { consumer->redeliverUnacknowledgedMessages(); }); unAckedMessageTrackerPtr_->clear(); } @@ -653,10 +636,9 @@ void MultiTopicsConsumerImpl::redeliverUnacknowledgedMessages(const std::set<Mes return; } LOG_DEBUG("Sending RedeliverUnacknowledgedMessages command for 
partitioned consumer."); - for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end(); - consumer++) { - (consumer->second)->redeliverUnacknowledgedMessages(messageIds); - } + consumers_.forEachValue([&messageIds](const ConsumerImplPtr& consumer) { + consumer->redeliverUnacknowledgedMessages(messageIds); + }); } int MultiTopicsConsumerImpl::getNumOfPrefetchedMessages() const { return messages_.size(); } @@ -671,15 +653,17 @@ void MultiTopicsConsumerImpl::getBrokerConsumerStatsAsync(BrokerConsumerStatsCal MultiTopicsBrokerConsumerStatsPtr statsPtr = std::make_shared<MultiTopicsBrokerConsumerStatsImpl>(numberTopicPartitions_->load()); LatchPtr latchPtr = std::make_shared<Latch>(numberTopicPartitions_->load()); - int size = consumers_.size(); lock.unlock(); - ConsumerMap::const_iterator consumer = consumers_.begin(); - for (int i = 0; i < size; i++, consumer++) { - consumer->second->getBrokerConsumerStatsAsync( - std::bind(&MultiTopicsConsumerImpl::handleGetConsumerStats, shared_from_this(), - std::placeholders::_1, std::placeholders::_2, latchPtr, statsPtr, i, callback)); - } + auto self = shared_from_this(); + size_t i = 0; + consumers_.forEachValue([&self, &latchPtr, &statsPtr, &i, callback](const ConsumerImplPtr& consumer) { + size_t index = i++; + consumer->getBrokerConsumerStatsAsync( + [self, latchPtr, statsPtr, index, callback](Result result, BrokerConsumerStats stats) { + self->handleGetConsumerStats(result, stats, latchPtr, statsPtr, index, callback); + }); + }); } void MultiTopicsConsumerImpl::handleGetConsumerStats(Result res, BrokerConsumerStats brokerConsumerStats, @@ -725,10 +709,9 @@ void MultiTopicsConsumerImpl::seekAsync(uint64_t timestamp, ResultCallback callb } void MultiTopicsConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) { - Lock lock(mutex_); - for (auto&& c : consumers_) { - c.second->setNegativeAcknowledgeEnabledForTesting(enabled); - } + consumers_.forEachValue([enabled](const 
ConsumerImplPtr& consumer) { + consumer->setNegativeAcknowledgeEnabledForTesting(enabled); + }); } bool MultiTopicsConsumerImpl::isConnected() const { @@ -736,24 +719,19 @@ bool MultiTopicsConsumerImpl::isConnected() const { if (state_ != Ready) { return false; } + lock.unlock(); - for (const auto& topicAndConsumer : consumers_) { - if (!topicAndConsumer.second->isConnected()) { - return false; - } - } - return true; + return consumers_ + .findFirstValueIf([](const ConsumerImplPtr& consumer) { return !consumer->isConnected(); }) + .is_empty(); } uint64_t MultiTopicsConsumerImpl::getNumberOfConnectedConsumer() { - Lock lock(mutex_); uint64_t numberOfConnectedConsumer = 0; - const auto consumers = consumers_; - lock.unlock(); - for (const auto& topicAndConsumer : consumers) { - if (topicAndConsumer.second->isConnected()) { + consumers_.forEachValue([&numberOfConnectedConsumer](const ConsumerImplPtr& consumer) { + if (consumer->isConnected()) { numberOfConnectedConsumer++; } - } + }); return numberOfConnectedConsumer; } diff --git a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h index aa6b261..98b2f31 100644 --- a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h +++ b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h @@ -32,6 +32,7 @@ #include <lib/MultiTopicsBrokerConsumerStatsImpl.h> #include <lib/TopicName.h> #include <lib/NamespaceName.h> +#include <lib/SynchronizedHashMap.h> namespace pulsar { typedef std::shared_ptr<Promise<Result, Consumer>> ConsumerSubResultPromisePtr; @@ -93,7 +94,7 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase, std::string consumerStr_; std::string topic_; const ConsumerConfiguration conf_; - typedef std::map<std::string, ConsumerImplPtr> ConsumerMap; + typedef SynchronizedHashMap<std::string, ConsumerImplPtr> ConsumerMap; ConsumerMap consumers_; std::map<std::string, int> topicsPartitions_; mutable std::mutex mutex_; @@ -115,7 +116,7 @@ class MultiTopicsConsumerImpl : public 
ConsumerImplBase, void handleSinglePartitionConsumerCreated(Result result, ConsumerImplBaseWeakPtr consumerImplBaseWeakPtr, unsigned int partitionIndex); - void handleSingleConsumerClose(Result result, std::string& topicPartitionName, CloseCallback callback); + void handleSingleConsumerClose(Result result, std::string topicPartitionName, CloseCallback callback); void notifyResult(CloseCallback closeCallback); void messageReceived(Consumer consumer, const Message& msg); void internalListener(Consumer consumer); diff --git a/pulsar-client-cpp/lib/SynchronizedHashMap.h b/pulsar-client-cpp/lib/SynchronizedHashMap.h new file mode 100644 index 0000000..3a78467 --- /dev/null +++ b/pulsar-client-cpp/lib/SynchronizedHashMap.h @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#pragma once + +#include <functional> +#include <mutex> +#include <unordered_map> +#include <utility> +#include <vector> +#include "Utils.h" + +namespace pulsar { + +// V must be default constructible and copyable +template <typename K, typename V> +class SynchronizedHashMap { + using MutexType = std::recursive_mutex; + using Lock = std::lock_guard<MutexType>; + + public: + using OptValue = Optional<V>; + using PairVector = std::vector<std::pair<K, V>>; + + SynchronizedHashMap() = default; + + SynchronizedHashMap(const PairVector& pairs) { + for (auto&& kv : pairs) { + data_.emplace(kv.first, kv.second); + } + } + + template <typename... Args> + void emplace(Args&&... args) { + Lock lock(mutex_); + data_.emplace(std::forward<Args>(args)...); + } + + void forEach(std::function<void(const K&, const V&)> f) const { + Lock lock(mutex_); + for (const auto& kv : data_) { + f(kv.first, kv.second); + } + } + + void forEachValue(std::function<void(const V&)> f) const { + Lock lock(mutex_); + for (const auto& kv : data_) { + f(kv.second); + } + } + + void clear() { + Lock lock(mutex_); + data_.clear(); + } + + OptValue find(const K& key) const { + Lock lock(mutex_); + auto it = data_.find(key); + if (it != data_.end()) { + return OptValue::of(it->second); + } else { + return OptValue::empty(); + } + } + + OptValue findFirstValueIf(std::function<bool(const V&)> f) const { + Lock lock(mutex_); + for (const auto& kv : data_) { + if (f(kv.second)) { + return OptValue::of(kv.second); + } + } + return OptValue::empty(); + } + + OptValue remove(const K& key) { + Lock lock(mutex_); + auto it = data_.find(key); + if (it != data_.end()) { + auto result = OptValue::of(it->second); + data_.erase(it); + return result; + } else { + return OptValue::empty(); + } + } + + // This method is only used for test + PairVector toPairVector() const { + Lock lock(mutex_); + PairVector pairs; + for (auto&& kv : data_) { + pairs.emplace_back(kv); + } + return pairs; + } + + // This method is 
only used for test + size_t size() const noexcept { + Lock lock(mutex_); + return data_.size(); + } + + private: + std::unordered_map<K, V> data_; + // Use recursive_mutex to allow methods being called in `forEach` + mutable MutexType mutex_; +}; + +} // namespace pulsar diff --git a/pulsar-client-cpp/tests/ConsumerTest.cc b/pulsar-client-cpp/tests/ConsumerTest.cc index 100086e..b61c15a 100644 --- a/pulsar-client-cpp/tests/ConsumerTest.cc +++ b/pulsar-client-cpp/tests/ConsumerTest.cc @@ -530,11 +530,14 @@ TEST(ConsumerTest, testMultiTopicsConsumerUnAckedMessageRedelivery) { multiTopicsConsumerImplPtr->unAckedMessageTrackerPtr_.get()); ASSERT_EQ(numOfMessages * 3, multiTopicsTracker->size()); ASSERT_FALSE(multiTopicsTracker->isEmpty()); - for (auto iter = multiTopicsConsumerImplPtr->consumers_.begin(); - iter != multiTopicsConsumerImplPtr->consumers_.end(); ++iter) { - auto subConsumerPtr = iter->second; - auto tracker = - static_cast<UnAckedMessageTrackerEnabled*>(subConsumerPtr->unAckedMessageTrackerPtr_.get()); + + std::vector<UnAckedMessageTrackerEnabled*> trackers; + multiTopicsConsumerImplPtr->consumers_.forEach( + [&trackers](const std::string& name, const ConsumerImplPtr& consumer) { + trackers.emplace_back( + static_cast<UnAckedMessageTrackerEnabled*>(consumer->unAckedMessageTrackerPtr_.get())); + }); + for (const auto& tracker : trackers) { ASSERT_EQ(0, tracker->size()); ASSERT_TRUE(tracker->isEmpty()); } diff --git a/pulsar-client-cpp/tests/SynchronizedHashMapTest.cc b/pulsar-client-cpp/tests/SynchronizedHashMapTest.cc new file mode 100644 index 0000000..62c55c4 --- /dev/null +++ b/pulsar-client-cpp/tests/SynchronizedHashMapTest.cc @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <gtest/gtest.h> +#include <algorithm> +#include <atomic> +#include <chrono> +#include <thread> +#include <vector> +#include "lib/Latch.h" +#include "lib/SynchronizedHashMap.h" + +using namespace pulsar; +using SyncMapType = SynchronizedHashMap<int, int>; +using OptValue = typename SyncMapType::OptValue; +using PairVector = typename SyncMapType::PairVector; + +inline void sleepMs(long millis) { std::this_thread::sleep_for(std::chrono::milliseconds(millis)); } + +inline PairVector sort(PairVector pairs) { + std::sort(pairs.begin(), pairs.end(), [](const std::pair<int, int>& lhs, const std::pair<int, int>& rhs) { + return lhs.first < rhs.first; + }); + return pairs; +} + +TEST(SynchronizedHashMap, testClear) { + SynchronizedHashMap<int, int> m({{1, 100}, {2, 200}}); + m.clear(); + ASSERT_EQ(m.toPairVector(), PairVector{}); +} + +TEST(SynchronizedHashMap, testRemoveAndFind) { + SyncMapType m({{1, 100}, {2, 200}, {3, 300}}); + + OptValue optValue; + optValue = m.findFirstValueIf([](const int& x) { return x == 200; }); + ASSERT_TRUE(optValue.is_present()); + ASSERT_EQ(optValue.value(), 200); + + optValue = m.findFirstValueIf([](const int& x) { return x >= 301; }); + ASSERT_FALSE(optValue.is_present()); + + optValue = m.find(1); + ASSERT_TRUE(optValue.is_present()); + ASSERT_EQ(optValue.value(), 100); + + ASSERT_FALSE(m.find(0).is_present()); + 
ASSERT_FALSE(m.remove(0).is_present()); + + optValue = m.remove(1); + ASSERT_TRUE(optValue.is_present()); + ASSERT_EQ(optValue.value(), 100); + + ASSERT_FALSE(m.remove(1).is_present()); + ASSERT_FALSE(m.find(1).is_present()); +} + +TEST(SynchronizedHashMapTest, testForEach) { + SyncMapType m({{1, 100}, {2, 200}, {3, 300}}); + std::vector<int> values; + m.forEachValue([&values](const int& value) { values.emplace_back(value); }); + std::sort(values.begin(), values.end()); + ASSERT_EQ(values, std::vector<int>({100, 200, 300})); + + PairVector pairs; + m.forEach([&pairs](const int& key, const int& value) { pairs.emplace_back(key, value); }); + PairVector expectedPairs({{1, 100}, {2, 200}, {3, 300}}); + ASSERT_EQ(sort(pairs), expectedPairs); +} + +TEST(SynchronizedHashMap, testRecursiveMutex) { + SyncMapType m({{1, 100}}); + OptValue optValue; + m.forEach([&m, &optValue](const int& key, const int& value) { + optValue = m.find(key); // the internal mutex was locked again + }); + ASSERT_TRUE(optValue.is_present()); + ASSERT_EQ(optValue.value(), 100); +} + +TEST(SynchronizedHashMapTest, testThreadSafeForEach) { + SyncMapType m({{1, 100}, {2, 200}, {3, 300}}); + + Latch latch(1); + std::thread t{[&m, &latch] { + latch.wait(); // this thread must start after `m.forEach` started + m.remove(2); + }}; + + std::atomic_bool firstElementDone{false}; + PairVector pairs; + m.forEach([&latch, &firstElementDone, &pairs](const int& key, const int& value) { + pairs.emplace_back(key, value); + if (!firstElementDone) { + latch.countdown(); + firstElementDone = true; + } + sleepMs(200); + }); + { + PairVector expectedPairs({{1, 100}, {2, 200}, {3, 300}}); + ASSERT_EQ(sort(pairs), expectedPairs); + } + t.join(); + { + PairVector expectedPairs({{1, 100}, {3, 300}}); + ASSERT_EQ(sort(m.toPairVector()), expectedPairs); + } +}
