This is an automated email from the ASF dual-hosted git repository.
xyz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pulsar-client-cpp.git
The following commit(s) were added to refs/heads/main by this push:
new 20f6fa0 Fix the buggy Future and Promise implementations (#299)
20f6fa0 is described below
commit 20f6fa0a72929f8e2668f45790f44e6a00390a44
Author: Yunze Xu <[email protected]>
AuthorDate: Wed Jul 5 14:49:31 2023 +0800
Fix the buggy Future and Promise implementations (#299)
Fixes https://github.com/apache/pulsar-client-cpp/issues/298
### Motivation
Currently, `Future` and `Promise` are implemented manually by
managing condition variables. However, the condition variable
sometimes behaves incorrectly on macOS, while the `future`
and `promise` from the C++ standard library work well.
### Modifications
Redesign `Future` and `Promise` based on the utilities in the standard
`<future>` header. In addition, fix the possible race condition when
`addListener` is called after `setValue` or `setFailed`:
- Thread 1: call `setValue`, swap out the existing listeners and call them
one by one outside the lock.
- Thread 2: call `addListener`, detect `complete_` is true and call the
listener directly.
As a result, the previous listeners (in thread 1) and the new listener
(in thread 2) can be called concurrently.
This patch fixes the problem by adding a future that waits until all
listeners that were added before completion are done.
### Verifications
Ran the reproduction code from #298 10 times; it never failed or
hung.
Co-authored-by: Zike Yang <[email protected]>
---------
Co-authored-by: Zike Yang <[email protected]>
---
lib/BinaryProtoLookupService.cc | 2 +-
lib/Future.h | 201 +++++++++++++++++-----------------------
lib/stats/ProducerStatsImpl.cc | 6 +-
tests/BasicEndToEndTest.cc | 4 +-
tests/PromiseTest.cc | 27 ++++++
5 files changed, 119 insertions(+), 121 deletions(-)
diff --git a/lib/BinaryProtoLookupService.cc b/lib/BinaryProtoLookupService.cc
index f563f63..dfa3cab 100644
--- a/lib/BinaryProtoLookupService.cc
+++ b/lib/BinaryProtoLookupService.cc
@@ -146,7 +146,7 @@ void
BinaryProtoLookupService::handlePartitionMetadataLookup(const std::string&
}
uint64_t BinaryProtoLookupService::newRequestId() {
- Lock lock(mutex_);
+ std::lock_guard<std::mutex> lock(mutex_);
return ++requestIdGenerator_;
}
diff --git a/lib/Future.h b/lib/Future.h
index 3593057..03e93e4 100644
--- a/lib/Future.h
+++ b/lib/Future.h
@@ -19,162 +19,133 @@
#ifndef LIB_FUTURE_H_
#define LIB_FUTURE_H_
-#include <condition_variable>
+#include <atomic>
+#include <chrono>
#include <functional>
+#include <future>
#include <list>
#include <memory>
#include <mutex>
-
-using Lock = std::unique_lock<std::mutex>;
+#include <thread>
+#include <utility>
namespace pulsar {
template <typename Result, typename Type>
-struct InternalState {
- std::mutex mutex;
- std::condition_variable condition;
- Result result;
- Type value;
- bool complete;
-
- std::list<typename std::function<void(Result, const Type&)> > listeners;
-};
-
-template <typename Result, typename Type>
-class Future {
+class InternalState {
public:
- typedef std::function<void(Result, const Type&)> ListenerCallback;
-
- Future& addListener(ListenerCallback callback) {
- InternalState<Result, Type>* state = state_.get();
- Lock lock(state->mutex);
-
- if (state->complete) {
- lock.unlock();
- callback(state->result, state->value);
- } else {
- state->listeners.push_back(callback);
- }
+ using Listener = std::function<void(Result, const Type &)>;
+ using Pair = std::pair<Result, Type>;
+ using Lock = std::unique_lock<std::mutex>;
- return *this;
- }
+ // NOTE: Add the constructor explicitly just to be compatible with GCC 4.8
+ InternalState() {}
- Result get(Type& result) {
- InternalState<Result, Type>* state = state_.get();
- Lock lock(state->mutex);
+ void addListener(Listener listener) {
+ Lock lock{mutex_};
+ listeners_.emplace_back(listener);
+ lock.unlock();
- if (!state->complete) {
- // Wait for result
- while (!state->complete) {
- state->condition.wait(lock);
- }
+ if (completed()) {
+ Type value;
+ Result result = get(value);
+ triggerListeners(result, value);
}
-
- result = state->value;
- return state->result;
}
- template <typename Duration>
- bool get(Result& res, Type& value, Duration d) {
- InternalState<Result, Type>* state = state_.get();
- Lock lock(state->mutex);
-
- if (!state->complete) {
- // Wait for result
- while (!state->complete) {
- if (!state->condition.wait_for(lock, d, [&state] { return
state->complete; })) {
- // Timeout while waiting for the future to complete
- return false;
- }
- }
+ bool complete(Result result, const Type &value) {
+ bool expected = false;
+ if (!completed_.compare_exchange_strong(expected, true)) {
+ return false;
}
-
- value = state->value;
- res = state->result;
+ triggerListeners(result, value);
+ promise_.set_value(std::make_pair(result, value));
return true;
}
- private:
- typedef std::shared_ptr<InternalState<Result, Type> > InternalStatePtr;
- Future(InternalStatePtr state) : state_(state) {}
+ bool completed() const noexcept { return completed_; }
- std::shared_ptr<InternalState<Result, Type> > state_;
-
- template <typename U, typename V>
- friend class Promise;
-};
+ Result get(Type &result) {
+ const auto &pair = future_.get();
+ result = pair.second;
+ return pair.first;
+ }
-template <typename Result, typename Type>
-class Promise {
- public:
- Promise() : state_(std::make_shared<InternalState<Result, Type> >()) {}
+ // Only public for test
+ void triggerListeners(Result result, const Type &value) {
+ while (true) {
+ Lock lock{mutex_};
+ if (listeners_.empty()) {
+ return;
+ }
- bool setValue(const Type& value) const {
- static Result DEFAULT_RESULT;
- InternalState<Result, Type>* state = state_.get();
- Lock lock(state->mutex);
+ bool expected = false;
+ if (!listenerRunning_.compare_exchange_strong(expected, true)) {
+ // There is another thread that polled a listener that is
running, skip polling and release
+ // the lock. Here we wait for some time to avoid busy waiting.
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ continue;
+ }
+ auto listener = std::move(listeners_.front());
+ listeners_.pop_front();
+ lock.unlock();
- if (state->complete) {
- return false;
+ listener(result, value);
+ listenerRunning_ = false;
}
+ }
- state->value = value;
- state->result = DEFAULT_RESULT;
- state->complete = true;
+ private:
+ std::atomic_bool completed_{false};
+ std::promise<Pair> promise_;
+ std::shared_future<Pair> future_{promise_.get_future()};
- decltype(state->listeners) listeners;
- listeners.swap(state->listeners);
+ std::list<Listener> listeners_;
+ mutable std::mutex mutex_;
+ std::atomic_bool listenerRunning_{false};
+};
- lock.unlock();
+template <typename Result, typename Type>
+using InternalStatePtr = std::shared_ptr<InternalState<Result, Type>>;
- for (auto& callback : listeners) {
- callback(DEFAULT_RESULT, value);
- }
+template <typename Result, typename Type>
+class Future {
+ public:
+ using Listener = typename InternalState<Result, Type>::Listener;
- state->condition.notify_all();
- return true;
+ Future &addListener(Listener listener) {
+ state_->addListener(listener);
+ return *this;
}
- bool setFailed(Result result) const {
- static Type DEFAULT_VALUE;
- InternalState<Result, Type>* state = state_.get();
- Lock lock(state->mutex);
+ Result get(Type &result) { return state_->get(result); }
- if (state->complete) {
- return false;
- }
+ private:
+ InternalStatePtr<Result, Type> state_;
- state->result = result;
- state->complete = true;
+ Future(InternalStatePtr<Result, Type> state) : state_(state) {}
- decltype(state->listeners) listeners;
- listeners.swap(state->listeners);
+ template <typename U, typename V>
+ friend class Promise;
+};
- lock.unlock();
+template <typename Result, typename Type>
+class Promise {
+ public:
+ Promise() : state_(std::make_shared<InternalState<Result, Type>>()) {}
- for (auto& callback : listeners) {
- callback(result, DEFAULT_VALUE);
- }
+ bool setValue(const Type &value) const { return state_->complete({},
value); }
- state->condition.notify_all();
- return true;
- }
+ bool setFailed(Result result) const { return state_->complete(result, {});
}
- bool isComplete() const {
- InternalState<Result, Type>* state = state_.get();
- Lock lock(state->mutex);
- return state->complete;
- }
+ bool isComplete() const { return state_->completed(); }
- Future<Result, Type> getFuture() const { return Future<Result,
Type>(state_); }
+ Future<Result, Type> getFuture() const { return Future<Result,
Type>{state_}; }
private:
- typedef std::function<void(Result, const Type&)> ListenerCallback;
- std::shared_ptr<InternalState<Result, Type> > state_;
+ const InternalStatePtr<Result, Type> state_;
};
-class Void {};
-
-} /* namespace pulsar */
+} // namespace pulsar
-#endif /* LIB_FUTURE_H_ */
+#endif
diff --git a/lib/stats/ProducerStatsImpl.cc b/lib/stats/ProducerStatsImpl.cc
index 9b0f7e6..3d3629d 100644
--- a/lib/stats/ProducerStatsImpl.cc
+++ b/lib/stats/ProducerStatsImpl.cc
@@ -71,7 +71,7 @@ void ProducerStatsImpl::flushAndReset(const
boost::system::error_code& ec) {
return;
}
- Lock lock(mutex_);
+ std::unique_lock<std::mutex> lock(mutex_);
std::ostringstream oss;
oss << *this;
numMsgsSent_ = 0;
@@ -86,7 +86,7 @@ void ProducerStatsImpl::flushAndReset(const
boost::system::error_code& ec) {
}
void ProducerStatsImpl::messageSent(const Message& msg) {
- Lock lock(mutex_);
+ std::lock_guard<std::mutex> lock(mutex_);
numMsgsSent_++;
totalMsgsSent_++;
numBytesSent_ += msg.getLength();
@@ -96,7 +96,7 @@ void ProducerStatsImpl::messageSent(const Message& msg) {
void ProducerStatsImpl::messageReceived(Result res, const
boost::posix_time::ptime& publishTime) {
boost::posix_time::ptime currentTime =
boost::posix_time::microsec_clock::universal_time();
double diffInMicros = (currentTime - publishTime).total_microseconds();
- Lock lock(mutex_);
+ std::lock_guard<std::mutex> lock(mutex_);
totalLatencyAccumulator_(diffInMicros);
latencyAccumulator_(diffInMicros);
sendMap_[res] += 1; // Value will automatically be initialized to 0
in the constructor
diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc
index 8599b92..9ca2ab0 100644
--- a/tests/BasicEndToEndTest.cc
+++ b/tests/BasicEndToEndTest.cc
@@ -191,7 +191,7 @@ TEST(BasicEndToEndTest, testBatchMessages) {
}
void resendMessage(Result r, const MessageId msgId, Producer producer) {
- Lock lock(mutex_);
+ std::unique_lock<std::mutex> lock(mutex_);
if (r != ResultOk) {
LOG_DEBUG("globalResendMessageCount" << globalResendMessageCount);
if (++globalResendMessageCount >= 3) {
@@ -993,7 +993,7 @@ TEST(BasicEndToEndTest, testResendViaSendCallback) {
// 3 seconds
std::this_thread::sleep_for(std::chrono::microseconds(3 * 1000 * 1000));
producer.close();
- Lock lock(mutex_);
+ std::lock_guard<std::mutex> lock(mutex_);
ASSERT_GE(globalResendMessageCount, 3);
}
diff --git a/tests/PromiseTest.cc b/tests/PromiseTest.cc
index 25b6b72..29ee2a3 100644
--- a/tests/PromiseTest.cc
+++ b/tests/PromiseTest.cc
@@ -24,6 +24,9 @@
#include <vector>
#include "lib/Future.h"
+#include "lib/LogUtils.h"
+
+DECLARE_LOG_OBJECT()
using namespace pulsar;
@@ -84,3 +87,27 @@ TEST(PromiseTest, testListeners) {
ASSERT_EQ(results, (std::vector<int>(2, 0)));
ASSERT_EQ(values, (std::vector<std::string>(2, "hello")));
}
+
+TEST(PromiseTest, testTriggerListeners) {
+ InternalState<int, int> state;
+ state.addListener([](int, const int&) {
+ LOG_INFO("Start task 1...");
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ LOG_INFO("Finish task 1...");
+ });
+ state.addListener([](int, const int&) {
+ LOG_INFO("Start task 2...");
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ LOG_INFO("Finish task 2...");
+ });
+
+ auto start = std::chrono::high_resolution_clock::now();
+ auto future1 = std::async(std::launch::async, [&state] {
state.triggerListeners(0, 0); });
+ auto future2 = std::async(std::launch::async, [&state] {
state.triggerListeners(0, 0); });
+ future1.wait();
+ future2.wait();
+ auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::high_resolution_clock::now() - start)
+ .count();
+ ASSERT_TRUE(elapsed > 2000) << "elapsed: " << elapsed << "ms";
+}