From 0729a9ea2846e854cb4f5b71250a789319909c0c Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Mon, 4 Dec 2023 20:56:52 +0800 Subject: [PATCH] Fix accessing destroyed objects in the callback of async_wait Fixes https://github.com/apache/pulsar-client-cpp/issues/358 Fixes https://github.com/apache/pulsar-client-cpp/issues/359 ### Motivation `async_wait` is not used correctly in some places. A callback that captures the `this` pointer or reference to `this` is passed to `async_wait`, if this object is destroyed when the callback is called, an invalid memory access will happen. ### Modifications Use the following pattern in all `async_wait` calls. ```c++ std::weak_ptr weakSelf{shared_from_this()}; timer_->async_wait([weakSelf](/* ... */) { if (auto self = weakSelf.lock()) { self->foo(); } }); ``` --- lib/ConsumerImpl.cc | 12 +++++++----- lib/ConsumerImpl.h | 2 +- lib/NegativeAcksTracker.cc | 7 ++++++- lib/NegativeAcksTracker.h | 2 +- lib/PatternMultiTopicsConsumerImpl.cc | 17 +++++++++++++---- lib/PatternMultiTopicsConsumerImpl.h | 4 ++++ lib/UnAckedMessageTrackerEnabled.cc | 19 ++++++++++--------- lib/UnAckedMessageTrackerEnabled.h | 19 +++++++++++-------- lib/UnAckedMessageTrackerInterface.h | 2 ++ tests/ConsumerTest.cc | 2 +- 10 files changed, 56 insertions(+), 30 deletions(-) diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index b4666836..18c12ed5 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -86,7 +86,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const std::string& topic, consumerName_(config_.getConsumerName()), consumerStr_("[" + topic + ", " + subscriptionName + ", " + std::to_string(consumerId_) + "] "), messageListenerRunning_(true), - negativeAcksTracker_(client, *this, conf), + negativeAcksTracker_(std::make_shared(client, *this, conf)), readCompacted_(conf.isReadCompacted()), startMessageId_(startMessageId), maxPendingChunkedMessage_(conf.getMaxPendingChunkedMessage()), @@ -105,6 +105,7 @@ ConsumerImpl::ConsumerImpl(const ClientImplPtr client, const std::string& topic, } else { unAckedMessageTrackerPtr_.reset(new UnAckedMessageTrackerDisabled()); } + unAckedMessageTrackerPtr_->start(); // Setup stats reporter. unsigned int statsIntervalInSeconds = client->getClientConfig().getStatsIntervalInSeconds(); @@ -1228,7 +1229,7 @@ std::pair ConsumerImpl::prepareCumulativeAck(const MessageId& m void ConsumerImpl::negativeAcknowledge(const MessageId& messageId) { unAckedMessageTrackerPtr_->remove(messageId); - negativeAcksTracker_.add(messageId); + negativeAcksTracker_->add(messageId); } void ConsumerImpl::disconnectConsumer() { @@ -1266,7 +1267,7 @@ void ConsumerImpl::closeAsync(ResultCallback originalCallback) { if (ackGroupingTrackerPtr_) { ackGroupingTrackerPtr_->close(); } - negativeAcksTracker_.close(); + negativeAcksTracker_->close(); ClientConnectionPtr cnx = getCnx().lock(); if (!cnx) { @@ -1304,7 +1305,7 @@ void ConsumerImpl::shutdown() { if (client) { client->cleanupConsumer(this); } - negativeAcksTracker_.close(); + negativeAcksTracker_->close(); cancelTimers(); consumerCreatedPromise_.setFailed(ResultAlreadyClosed); failPendingReceiveCallback(); @@ -1609,7 +1610,7 @@ void ConsumerImpl::internalGetLastMessageIdAsync(const BackoffPtr& backoff, Time } void ConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) { - negativeAcksTracker_.setEnabledForTesting(enabled); + negativeAcksTracker_->setEnabledForTesting(enabled); } void ConsumerImpl::trackMessage(const MessageId& messageId) { @@ -1696,6 +1697,7 @@ void ConsumerImpl::cancelTimers() noexcept { boost::system::error_code ec; batchReceiveTimer_->cancel(ec); checkExpiredChunkedTimer_->cancel(ec); + unAckedMessageTrackerPtr_->stop(); } void ConsumerImpl::processPossibleToDLQ(const MessageId& messageId, ProcessDLQCallBack cb) { diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index 61d96b1c..32437091 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -224,7 +224,7 @@ class ConsumerImpl : public ConsumerImplBase { CompressionCodecProvider compressionCodecProvider_; UnAckedMessageTrackerPtr unAckedMessageTrackerPtr_; BrokerConsumerStatsImpl brokerConsumerStats_; - NegativeAcksTracker negativeAcksTracker_; + std::shared_ptr negativeAcksTracker_; AckGroupingTrackerPtr ackGroupingTrackerPtr_; MessageCryptoPtr msgCrypto_; diff --git a/lib/NegativeAcksTracker.cc b/lib/NegativeAcksTracker.cc index 5c3ef3f8..0dd73589 100644 --- a/lib/NegativeAcksTracker.cc +++ b/lib/NegativeAcksTracker.cc @@ -49,8 +49,13 @@ void NegativeAcksTracker::scheduleTimer() { if (closed_) { return; } + std::weak_ptr weakSelf{shared_from_this()}; timer_->expires_from_now(timerInterval_); - timer_->async_wait(std::bind(&NegativeAcksTracker::handleTimer, this, std::placeholders::_1)); + timer_->async_wait([weakSelf](const boost::system::error_code &ec) { + if (auto self = weakSelf.lock()) { + self->handleTimer(ec); + } + }); } void NegativeAcksTracker::handleTimer(const boost::system::error_code &ec) { diff --git a/lib/NegativeAcksTracker.h b/lib/NegativeAcksTracker.h index 029f7d24..4b489844 100644 --- a/lib/NegativeAcksTracker.h +++ b/lib/NegativeAcksTracker.h @@ -40,7 +40,7 @@ using DeadlineTimerPtr = std::shared_ptr; class ExecutorService; using ExecutorServicePtr = std::shared_ptr; -class NegativeAcksTracker { +class NegativeAcksTracker : public std::enable_shared_from_this { public: NegativeAcksTracker(ClientImplPtr client, ConsumerImpl &consumer, const ConsumerConfiguration &conf); diff --git a/lib/PatternMultiTopicsConsumerImpl.cc b/lib/PatternMultiTopicsConsumerImpl.cc index e100a1c3..23e445ee 100644 --- a/lib/PatternMultiTopicsConsumerImpl.cc +++ b/lib/PatternMultiTopicsConsumerImpl.cc @@ -47,8 +47,13 @@ const PULSAR_REGEX_NAMESPACE::regex PatternMultiTopicsConsumerImpl::getPattern() void PatternMultiTopicsConsumerImpl::resetAutoDiscoveryTimer() { autoDiscoveryRunning_ = false; autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod())); - autoDiscoveryTimer_->async_wait( - std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, this, std::placeholders::_1)); + + auto weakSelf = weak_from_this(); + autoDiscoveryTimer_->async_wait([weakSelf](const boost::system::error_code& err) { + if (auto self = weakSelf.lock()) { + self->autoDiscoveryTimerTask(err); + } + }); } void PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask(const boost::system::error_code& err) { @@ -222,8 +227,12 @@ void PatternMultiTopicsConsumerImpl::start() { if (conf_.getPatternAutoDiscoveryPeriod() > 0) { autoDiscoveryTimer_->expires_from_now(seconds(conf_.getPatternAutoDiscoveryPeriod())); - autoDiscoveryTimer_->async_wait( - std::bind(&PatternMultiTopicsConsumerImpl::autoDiscoveryTimerTask, this, std::placeholders::_1)); + auto weakSelf = weak_from_this(); + autoDiscoveryTimer_->async_wait([weakSelf](const boost::system::error_code& err) { + if (auto self = weakSelf.lock()) { + self->autoDiscoveryTimerTask(err); + } + }); } } diff --git a/lib/PatternMultiTopicsConsumerImpl.h b/lib/PatternMultiTopicsConsumerImpl.h index f13750a9..5d3ba9ec 100644 --- a/lib/PatternMultiTopicsConsumerImpl.h +++ b/lib/PatternMultiTopicsConsumerImpl.h @@ -86,6 +86,10 @@ class PatternMultiTopicsConsumerImpl : public MultiTopicsConsumerImpl { void onTopicsRemoved(NamespaceTopicsPtr removedTopics, ResultCallback callback); void handleOneTopicAdded(const Result result, const std::string& topic, std::shared_ptr> topicsNeedCreate, ResultCallback callback); + + std::weak_ptr weak_from_this() noexcept { + return std::static_pointer_cast(shared_from_this()); + } }; } // namespace pulsar diff --git a/lib/UnAckedMessageTrackerEnabled.cc b/lib/UnAckedMessageTrackerEnabled.cc index ff1b928f..061a1409 100644 --- a/lib/UnAckedMessageTrackerEnabled.cc +++ b/lib/UnAckedMessageTrackerEnabled.cc @@ -35,11 +35,11 @@ void UnAckedMessageTrackerEnabled::timeoutHandler() { ExecutorServicePtr executorService = client_->getIOExecutorProvider()->get(); timer_ = executorService->createDeadlineTimer(); timer_->expires_from_now(boost::posix_time::milliseconds(tickDurationInMs_)); - timer_->async_wait([&](const boost::system::error_code& ec) { - if (ec) { - LOG_DEBUG("Ignoring timer cancelled event, code[" << ec << "]"); - } else { - timeoutHandler(); + std::weak_ptr weakSelf{shared_from_this()}; + timer_->async_wait([weakSelf](const boost::system::error_code& ec) { + auto self = weakSelf.lock(); + if (self && !ec) { + self->timeoutHandler(); } }); } @@ -91,10 +91,10 @@ UnAckedMessageTrackerEnabled::UnAckedMessageTrackerEnabled(long timeoutMs, long std::set msgIds; timePartitions.push_back(msgIds); } - - timeoutHandler(); } +void UnAckedMessageTrackerEnabled::start() { timeoutHandler(); } + bool UnAckedMessageTrackerEnabled::add(const MessageId& msgId) { std::lock_guard acquire(lock_); auto id = discardBatch(msgId); @@ -172,9 +172,10 @@ void UnAckedMessageTrackerEnabled::clear() { } } -UnAckedMessageTrackerEnabled::~UnAckedMessageTrackerEnabled() { +void UnAckedMessageTrackerEnabled::stop() { + boost::system::error_code ec; if (timer_) { - timer_->cancel(); + timer_->cancel(ec); } } } /* namespace pulsar */ diff --git a/lib/UnAckedMessageTrackerEnabled.h b/lib/UnAckedMessageTrackerEnabled.h index 1453460c..6181a8a3 100644 --- a/lib/UnAckedMessageTrackerEnabled.h +++ b/lib/UnAckedMessageTrackerEnabled.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -34,19 +35,21 @@ class ConsumerImplBase; using ClientImplPtr = std::shared_ptr; using DeadlineTimerPtr = std::shared_ptr; -class UnAckedMessageTrackerEnabled : public UnAckedMessageTrackerInterface { +class UnAckedMessageTrackerEnabled : public std::enable_shared_from_this, + public UnAckedMessageTrackerInterface { public: - ~UnAckedMessageTrackerEnabled(); UnAckedMessageTrackerEnabled(long timeoutMs, ClientImplPtr, ConsumerImplBase&); UnAckedMessageTrackerEnabled(long timeoutMs, long tickDuration, ClientImplPtr, ConsumerImplBase&); - bool add(const MessageId& msgId); - bool remove(const MessageId& msgId); - void remove(const MessageIdList& msgIds); - void removeMessagesTill(const MessageId& msgId); - void removeTopicMessage(const std::string& topic); + void start() override; + void stop() override; + bool add(const MessageId& msgId) override; + bool remove(const MessageId& msgId) override; + void remove(const MessageIdList& msgIds) override; + void removeMessagesTill(const MessageId& msgId) override; + void removeTopicMessage(const std::string& topic) override; void timeoutHandler(); - void clear(); + void clear() override; protected: void timeoutHandlerHelper(); diff --git a/lib/UnAckedMessageTrackerInterface.h b/lib/UnAckedMessageTrackerInterface.h index d1fe7893..4df8819e 100644 --- a/lib/UnAckedMessageTrackerInterface.h +++ b/lib/UnAckedMessageTrackerInterface.h @@ -28,6 +28,8 @@ class UnAckedMessageTrackerInterface { public: virtual ~UnAckedMessageTrackerInterface() {} UnAckedMessageTrackerInterface() {} + virtual void start() {} + virtual void stop() {} virtual bool add(const MessageId& m) = 0; virtual bool remove(const MessageId& m) = 0; virtual void remove(const MessageIdList& msgIds) = 0; diff --git a/tests/ConsumerTest.cc b/tests/ConsumerTest.cc index 0836fbf4..c1b342ca 100644 --- a/tests/ConsumerTest.cc +++ b/tests/ConsumerTest.cc @@ -1222,7 +1222,7 @@ TEST(ConsumerTest, testNegativeAcksTrackerClose) { consumer.close(); auto consumerImplPtr = PulsarFriend::getConsumerImplPtr(consumer); - ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_.nackedMessages_.empty()); + ASSERT_TRUE(consumerImplPtr->negativeAcksTracker_->nackedMessages_.empty()); client.close(); }