This is an automated email from the ASF dual-hosted git repository.

xyz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/pulsar-client-cpp.git

commit 6cb391abd4d4f84f3e58f94722290791269a3b0e
Author: Yunze Xu <[email protected]>
AuthorDate: Wed Dec 6 10:37:35 2023 +0800

    Fix accessing destroyed objects in the callback of async_wait (#362)
    
    Fixes https://github.com/apache/pulsar-client-cpp/issues/358
    Fixes https://github.com/apache/pulsar-client-cpp/issues/359
    
    ### Motivation
    
    `async_wait` is not used correctly in some places: a callback that
    captures the `this` pointer (or a reference to `this`) is passed to
    `async_wait`, so if the object has already been destroyed by the time
    the callback is invoked, an invalid memory access happens.
    
    ### Modifications
    
    Use the following pattern in all `async_wait` calls, so that the callback
    only accesses the object if it is still alive:
    
    ```c++
    std::weak_ptr<T> weakSelf{shared_from_this()};
    timer_->async_wait([weakSelf](/* ... */) {
        if (auto self = weakSelf.lock()) {
            self->foo();
        }
    });
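    ```
    
    Because `shared_from_this()` only works once the object is already owned
    by a `std::shared_ptr` (and therefore cannot be used inside a
    constructor):
    
    - `ConsumerImpl` now holds `NegativeAcksTracker` through a
      `std::shared_ptr` instead of by value.
    - `UnAckedMessageTrackerEnabled` no longer arms its timer in its
      constructor. Callers must call the new `start()` method after
      construction. The timer is cancelled by the new `stop()` method
      (called from `ConsumerImpl::cancelTimers()`) instead of in the
      destructor. `UnAckedMessageTrackerInterface` provides no-op defaults
      for both methods.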
    
    (cherry picked from commit 24ab12c74127276c0e264ce869e553749b9170e8)
---
 lib/ConsumerImpl.cc                   | 12 +++++++-----
 lib/ConsumerImpl.h                    |  2 +-
 lib/MultiTopicsConsumerImpl.cc        |  1 +
 lib/NegativeAcksTracker.cc            |  7 ++++++-
 lib/NegativeAcksTracker.h             |  2 +-
 lib/PatternMultiTopicsConsumerImpl.cc | 17 +++++++++++++----
 lib/PatternMultiTopicsConsumerImpl.h  |  4 ++++
 lib/UnAckedMessageTrackerEnabled.cc   | 19 ++++++++++---------
 lib/UnAckedMessageTrackerEnabled.h    | 19 +++++++++++--------
 lib/UnAckedMessageTrackerInterface.h  |  2 ++
 tests/BasicEndToEndTest.cc            |  2 ++
 tests/ConsumerTest.cc                 |  4 +++-
 12 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc
index b466683..18c12ed 100644
--- a/lib/ConsumerImpl.cc
+++ b/lib/ConsumerImpl.cc
@@ -86,7 +86,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const 
std::string& topic,
       consumerName_(config_.getConsumerName()),
       consumerStr_("[" + topic + ", " + subscriptionName + ", " + 
std::to_string(consumerId_) + "] "),
       messageListenerRunning_(true),
-      negativeAcksTracker_(client, *this, conf),
+      negativeAcksTracker_(std::make_shared<NegativeAcksTracker>(client, 
*this, conf)),
       readCompacted_(conf.isReadCompacted()),
       startMessageId_(startMessageId),
       maxPendingChunkedMessage_(conf.getMaxPendingChunkedMessage()),
@@ -105,6 +105,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, 
const std::string& topic,
     } else {
         unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerDisabled());
     }
+    unAckedMessageTrackerPtr_->start();
 
     // Setup stats reporter.
     unsigned int statsIntervalInSeconds = 
client->getClientConfig().getStatsIntervalInSeconds();
@@ -1228,7 +1229,7 @@ std::pair<MessageId, bool> 
ConsumerImpl::prepareCumulativeAck(const MessageId& m
 
 void ConsumerImpl::negativeAcknowledge(const MessageId& messageId) {
     unAckedMessageTrackerPtr_->remove(messageId);
-    negativeAcksTracker_.add(messageId);
+    negativeAcksTracker_->add(messageId);
 }
 
 void ConsumerImpl::disconnectConsumer() {
@@ -1266,7 +1267,7 @@ void ConsumerImpl::closeAsync(ResultCallback 
originalCallback) {
     if (ackGroupingTrackerPtr_) {
         ackGroupingTrackerPtr_->close();
     }
-    negativeAcksTracker_.close();
+    negativeAcksTracker_->close();
 
     ClientConnectionPtr cnx = getCnx().lock();
     if (!cnx) {
@@ -1304,7 +1305,7 @@ void ConsumerImpl::shutdown() {
     if (client) {
         client->cleanupConsumer(this);
     }
-    negativeAcksTracker_.close();
+    negativeAcksTracker_->close();
     cancelTimers();
     consumerCreatedPromise_.setFailed(ResultAlreadyClosed);
     failPendingReceiveCallback();
@@ -1609,7 +1610,7 @@ void ConsumerImpl::internalGetLastMessageIdAsync(const 
BackoffPtr& backoff, Time
 }
 
 void ConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) {
-    negativeAcksTracker_.setEnabledForTesting(enabled);
+    negativeAcksTracker_->setEnabledForTesting(enabled);
 }
 
 void ConsumerImpl::trackMessage(const MessageId& messageId) {
@@ -1696,6 +1697,7 @@ void ConsumerImpl::cancelTimers() noexcept {
     boost::system::error_code ec;
     batchReceiveTimer_->cancel(ec);
     checkExpiredChunkedTimer_->cancel(ec);
+    unAckedMessageTrackerPtr_->stop();
 }
 
 void ConsumerImpl::processPossibleToDLQ(const MessageId& messageId, 
ProcessDLQCallBack cb) {
diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h
index 61d96b1..3243709 100644
--- a/lib/ConsumerImpl.h
+++ b/lib/ConsumerImpl.h
@@ -224,7 +224,7 @@ class ConsumerImpl : public ConsumerImplBase {
     CompressionCodecProvider compressionCodecProvider_;
     UnAckedMessageTrackerPtr unAckedMessageTrackerPtr_;
     BrokerConsumerStatsImpl brokerConsumerStats_;
-    NegativeAcksTracker negativeAcksTracker_;
+    std::shared_ptr<NegativeAcksTracker> negativeAcksTracker_;
     AckGroupingTrackerPtr ackGroupingTrackerPtr_;
 
     MessageCryptoPtr msgCrypto_;
diff --git a/lib/MultiTopicsConsumerImpl.cc b/lib/MultiTopicsConsumerImpl.cc
index abc54c8..15f9d9b 100644
--- a/lib/MultiTopicsConsumerImpl.cc
+++ b/lib/MultiTopicsConsumerImpl.cc
@@ -86,6 +86,7 @@ 
MultiTopicsConsumerImpl::MultiTopicsConsumerImpl(ClientImplPtr client, const std
     } else {
         unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerDisabled());
     }
+    unAckedMessageTrackerPtr_->start();
     auto partitionsUpdateInterval = static_cast<unsigned 
int>(client->conf().getPartitionsUpdateInterval());
     if (partitionsUpdateInterval > 0) {
         partitionsUpdateTimer_ = listenerExecutor_->createDeadlineTimer();
diff --git a/lib/NegativeAcksTracker.cc b/lib/NegativeAcksTracker.cc
index 5c3ef3f..0dd7358 100644
--- a/lib/NegativeAcksTracker.cc
+++ b/lib/NegativeAcksTracker.cc
@@ -49,8 +49,13 @@ void NegativeAcksTracker::scheduleTimer() {
     if (closed_) {
         return;
     }
+    std::weak_ptr<NegativeAcksTracker> weakSelf{shared_from_this()};
     timer_->expires_from_now(timerInterval_);
-    timer_->async_wait(std::bind(&NegativeAcksTracker::handleTimer, this, 
std::placeholders::_1));
+    timer_->async_wait([weakSelf](const boost::system::error_code &ec) {
+        if (auto self = weakSelf.lock()) {
+            self->handleTimer(ec);
+        }
+    });
 }
 
 void NegativeAcksTracker::handleTimer(const boost::system::error_code &ec) {
diff --git a/lib/NegativeAcksTracker.h b/lib/NegativeAcksTracker.h
index 029f7d2..4b48984 100644
--- a/lib/NegativeAcksTracker.h
+++ b/lib/NegativeAcksTracker.h
@@ -40,7 +40,7 @@ using DeadlineTimerPtr = 
std::shared_ptr<boost::asio::deadline_timer>;
 class ExecutorService;
 using ExecutorServicePtr = std::shared_ptr<ExecutorService>;
 
-class NegativeAcksTracker {
+class NegativeAcksTracker : public 
std::enable_shared_from_this<NegativeAcksTracker> {
    public:
     NegativeAcksTracker(ClientImplPtr client, ConsumerImpl &consumer, const 
ConsumerConfiguration &conf);
 
diff --git a/lib/PatternMultiTopicsConsumerImpl.cc 
b/lib/PatternMultiTopicsConsumerImpl.cc
index e100a1c..23e445e 100644
--- a/lib/PatternMultiTopicsConsumerImpl.cc
+++ b/lib/PatternMultiTopicsConsumerImpl.cc
@@ -47,8 +47,13 @@ const PULSAR_REGEX_NAMESPACE::regex 
PatternMultiTopicsConsumerImpl::getPattern()
 void PatternMultiTopicsConsumerImpl::resetAutoDiscoveryTimer() {
     autoDiscoveryRunning_ = false;
     
autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod()));
-    autoDiscoveryTimer_->async_wait(
-        std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, 
this, std::placeholders::_1));
+
+    auto weakSelf = weak_from_this();
+    autoDiscoveryTimer_->async_wait([weakSelf](const 
boost::system::error_code& err) {
+        if (auto self = weakSelf.lock()) {
+            self->autoDiscoveryTimerTask(err);
+        }
+    });
 }
 
 void PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask(const 
boost::system::error_code& err) {
@@ -222,8 +227,12 @@ void PatternMultiTopicsConsumerImpl::start() {
 
     if (conf_.getPatternAutoDiscoveryPeriod() > 0) {
         
autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod()));
-        autoDiscoveryTimer_->async_wait(
-            std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, 
this, std::placeholders::_1));
+        auto weakSelf = weak_from_this();
+        autoDiscoveryTimer_->async_wait([weakSelf](const 
boost::system::error_code& err) {
+            if (auto self = weakSelf.lock()) {
+                self->autoDiscoveryTimerTask(err);
+            }
+        });
     }
 }
 
diff --git a/lib/PatternMultiTopicsConsumerImpl.h 
b/lib/PatternMultiTopicsConsumerImpl.h
index f13750a..5d3ba9e 100644
--- a/lib/PatternMultiTopicsConsumerImpl.h
+++ b/lib/PatternMultiTopicsConsumerImpl.h
@@ -86,6 +86,10 @@ class PatternMultiTopicsConsumerImpl : public 
MultiTopicsConsumerImpl {
     void onTopicsRemoved(NamespaceTopicsPtr removedTopics, ResultCallback 
callback);
     void handleOneTopicAdded(const Result result, const std::string& topic,
                              std::shared_ptr<std::atomic<int>> 
topicsNeedCreate, ResultCallback callback);
+
+    std::weak_ptr<PatternMultiTopicsConsumerImpl> weak_from_this() noexcept {
+        return 
std::static_pointer_cast<PatternMultiTopicsConsumerImpl>(shared_from_this());
+    }
 };
 
 }  // namespace pulsar
diff --git a/lib/UnAckedMessageTrackerEnabled.cc 
b/lib/UnAckedMessageTrackerEnabled.cc
index ff1b928..061a140 100644
--- a/lib/UnAckedMessageTrackerEnabled.cc
+++ b/lib/UnAckedMessageTrackerEnabled.cc
@@ -35,11 +35,11 @@ void UnAckedMessageTrackerEnabled::timeoutHandler() {
     ExecutorServicePtr executorService = 
client_->getIOExecutorProvider()->get();
     timer_ = executorService->createDeadlineTimer();
     
timer_->expires_from_now(boost::posix_time::milliseconds(tickDurationInMs_));
-    timer_->async_wait([&](const boost::system::error_code& ec) {
-        if (ec) {
-            LOG_DEBUG("Ignoring timer cancelled event, code[" << ec << "]");
-        } else {
-            timeoutHandler();
+    std::weak_ptr<UnAckedMessageTrackerEnabled> weakSelf{shared_from_this()};
+    timer_->async_wait([weakSelf](const boost::system::error_code& ec) {
+        auto self = weakSelf.lock();
+        if (self && !ec) {
+            self->timeoutHandler();
         }
     });
 }
@@ -91,10 +91,10 @@ 
UnAckedMessageTrackerEnabled::UnAckedMessageTrackerEnabled(long timeoutMs, long
         std::set<MessageId> msgIds;
         timePartitions.push_back(msgIds);
     }
-
-    timeoutHandler();
 }
 
+void UnAckedMessageTrackerEnabled::start() { timeoutHandler(); }
+
 bool UnAckedMessageTrackerEnabled::add(const MessageId& msgId) {
     std::lock_guard<std::recursive_mutex> acquire(lock_);
     auto id = discardBatch(msgId);
@@ -172,9 +172,10 @@ void UnAckedMessageTrackerEnabled::clear() {
     }
 }
 
-UnAckedMessageTrackerEnabled::~UnAckedMessageTrackerEnabled() {
+void UnAckedMessageTrackerEnabled::stop() {
+    boost::system::error_code ec;
     if (timer_) {
-        timer_->cancel();
+        timer_->cancel(ec);
     }
 }
 } /* namespace pulsar */
diff --git a/lib/UnAckedMessageTrackerEnabled.h 
b/lib/UnAckedMessageTrackerEnabled.h
index 1453460..6181a8a 100644
--- a/lib/UnAckedMessageTrackerEnabled.h
+++ b/lib/UnAckedMessageTrackerEnabled.h
@@ -21,6 +21,7 @@
 #include <boost/asio/deadline_timer.hpp>
 #include <deque>
 #include <map>
+#include <memory>
 #include <mutex>
 #include <set>
 
@@ -34,19 +35,21 @@ class ConsumerImplBase;
 using ClientImplPtr = std::shared_ptr<ClientImpl>;
 using DeadlineTimerPtr = std::shared_ptr<boost::asio::deadline_timer>;
 
-class UnAckedMessageTrackerEnabled : public UnAckedMessageTrackerInterface {
+class UnAckedMessageTrackerEnabled : public 
std::enable_shared_from_this<UnAckedMessageTrackerEnabled>,
+                                     public UnAckedMessageTrackerInterface {
    public:
-    ~UnAckedMessageTrackerEnabled();
     UnAckedMessageTrackerEnabled(long timeoutMs, ClientImplPtr, 
ConsumerImplBase&);
     UnAckedMessageTrackerEnabled(long timeoutMs, long tickDuration, 
ClientImplPtr, ConsumerImplBase&);
-    bool add(const MessageId& msgId);
-    bool remove(const MessageId& msgId);
-    void remove(const MessageIdList& msgIds);
-    void removeMessagesTill(const MessageId& msgId);
-    void removeTopicMessage(const std::string& topic);
+    void start() override;
+    void stop() override;
+    bool add(const MessageId& msgId) override;
+    bool remove(const MessageId& msgId) override;
+    void remove(const MessageIdList& msgIds) override;
+    void removeMessagesTill(const MessageId& msgId) override;
+    void removeTopicMessage(const std::string& topic) override;
     void timeoutHandler();
 
-    void clear();
+    void clear() override;
 
    protected:
     void timeoutHandlerHelper();
diff --git a/lib/UnAckedMessageTrackerInterface.h 
b/lib/UnAckedMessageTrackerInterface.h
index d1fe789..4df8819 100644
--- a/lib/UnAckedMessageTrackerInterface.h
+++ b/lib/UnAckedMessageTrackerInterface.h
@@ -28,6 +28,8 @@ class UnAckedMessageTrackerInterface {
    public:
     virtual ~UnAckedMessageTrackerInterface() {}
     UnAckedMessageTrackerInterface() {}
+    virtual void start() {}
+    virtual void stop() {}
     virtual bool add(const MessageId& m) = 0;
     virtual bool remove(const MessageId& m) = 0;
     virtual void remove(const MessageIdList& msgIds) = 0;
diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc
index e2c6697..5dbccbf 100644
--- a/tests/BasicEndToEndTest.cc
+++ b/tests/BasicEndToEndTest.cc
@@ -3973,6 +3973,7 @@ TEST(BasicEndToEndTest, 
testUnAckedMessageTrackerEnabledIndividualAck) {
 
     auto tracker0 = 
std::make_shared<UnAckedMessageTrackerEnabledMock>(unAckedMessagesTimeoutMs,
                                                                        
clientImplPtr, consumerImpl0);
+    tracker0->start();
     ASSERT_EQ(tracker0->getUnAckedMessagesTimeoutMs(), 
unAckedMessagesTimeoutMs);
     ASSERT_EQ(tracker0->getTickDurationInMs(), unAckedMessagesTimeoutMs);
 
@@ -4048,6 +4049,7 @@ TEST(BasicEndToEndTest, 
testUnAckedMessageTrackerEnabledCumulativeAck) {
     }
     auto tracker = 
std::make_shared<UnAckedMessageTrackerEnabledMock>(unAckedMessagesTimeoutMs, 
clientImplPtr,
                                                                       
consumerImpl0);
+    tracker->start();
     for (auto idx = 0; idx < numMsg; ++idx) {
         ASSERT_TRUE(tracker->add(recvMsgId[idx]));
     }
diff --git a/tests/ConsumerTest.cc b/tests/ConsumerTest.cc
index 0836fbf..fc745f3 100644
--- a/tests/ConsumerTest.cc
+++ b/tests/ConsumerTest.cc
@@ -993,6 +993,7 @@ TEST(ConsumerTest, 
testRedeliveryOfDecryptionFailedMessages) {
     auto consumer2ImplPtr = PulsarFriend::getConsumerImplPtr(consumer2);
     consumer2ImplPtr->unAckedMessageTrackerPtr_.reset(new 
UnAckedMessageTrackerEnabled(
         100, 100, PulsarFriend::getClientImplPtr(client), 
static_cast<ConsumerImplBase&>(*consumer2ImplPtr)));
+    consumer2ImplPtr->unAckedMessageTrackerPtr_->start();
 
     ConsumerConfiguration consConfig3;
     consConfig3.setConsumerType(pulsar::ConsumerShared);
@@ -1003,6 +1004,7 @@ TEST(ConsumerTest, 
testRedeliveryOfDecryptionFailedMessages) {
     auto consumer3ImplPtr = PulsarFriend::getConsumerImplPtr(consumer3);
     consumer3ImplPtr->unAckedMessageTrackerPtr_.reset(new 
UnAckedMessageTrackerEnabled(
         100, 100, PulsarFriend::getClientImplPtr(client), 
static_cast<ConsumerImplBase&>(*consumer3ImplPtr)));
+    consumer3ImplPtr->unAckedMessageTrackerPtr_->start();
 
     int numberOfMessages = 20;
     std::string msgContent = "msg-content";
@@ -1222,7 +1224,7 @@ TEST(ConsumerTest, testNegativeAcksTrackerClose) {
 
     consumer.close();
     auto consumerImplPtr = PulsarFriend::getConsumerImplPtr(consumer);
-    ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_.nackedMessages_.empty());
+    
ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_->nackedMessages_.empty());
 
     client.close();
 }

Reply via email to