You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pulsar.apache.org by rg...@apache.org on 2022/03/11 04:35:49 UTC
[pulsar] 06/06: [C++] Fix thread safety issue for multi topic consumer (#14380)
This is an automated email from the ASF dual-hosted git repository.
rgao pushed a commit to branch branch-2.9
in repository https://gitbox.apache.org/repos/asf/pulsar.git
commit a5612df3ea0d597eec16f198b5eab18bf06a0432
Author: Yunze Xu <xy...@163.com>
AuthorDate: Thu Mar 10 16:03:02 2022 +0800
[C++] Fix thread safety issue for multi topic consumer (#14380)
* [C++] Fix thread safety issue for multi topic consumer
**Motivation**
In the C++ client, if a consumer subscribes to multiple topics, a
`MultiTopicsConsumerImpl` object, which manages a map of
`ConsumerImpl`s keyed by topic name (`consumers_` field), is created. However,
`consumers_` could be accessed by multiple threads while no
mutex protected that access, so it was not thread safe.
**Modifications**
- Add a `SynchronizedHashMap` class, which implements thread-safe
traverse, remove, find, and clear operations. Since the
`forEach` methods can call other methods, use a recursive mutex
instead of the default mutex.
- Add a related test `SynchronizedHashMapTest` to test the methods and
the thread safety of `SynchronizedHashMap`.
- Use `SynchronizedHashMap` as the type of
`MultiTopicsConsumerImpl::consumers_`.
* Add findFirstValueIf method
* Remove unnecessary return value of forEach
* Fix incorrect calls of forEachValue
* Add missing header
(cherry picked from commit f94eba942b9fb3d2c25b6f7a9e2c0885a194efa0)
---
pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc | 166 +++++++++------------
pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h | 5 +-
pulsar-client-cpp/lib/SynchronizedHashMap.h | 127 ++++++++++++++++
pulsar-client-cpp/tests/ConsumerTest.cc | 13 +-
pulsar-client-cpp/tests/SynchronizedHashMapTest.cc | 125 ++++++++++++++++
5 files changed, 335 insertions(+), 101 deletions(-)
diff --git a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc
index 4e31e64..0ae86d5 100644
--- a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc
+++ b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.cc
@@ -171,7 +171,7 @@ void MultiTopicsConsumerImpl::subscribeTopicPartitions(const Result result,
consumer->getConsumerCreatedFuture().addListener(std::bind(
&MultiTopicsConsumerImpl::handleSingleConsumerCreated, shared_from_this(), std::placeholders::_1,
std::placeholders::_2, partitionsNeedCreate, topicSubResultPromise));
- consumers_.insert(std::make_pair(topicName->toString(), consumer));
+ consumers_.emplace(topicName->toString(), consumer);
LOG_DEBUG("Creating Consumer for - " << topicName << " - " << consumerStr_);
consumer->start();
@@ -184,7 +184,7 @@ void MultiTopicsConsumerImpl::subscribeTopicPartitions(const Result result,
&MultiTopicsConsumerImpl::handleSingleConsumerCreated, shared_from_this(),
std::placeholders::_1, std::placeholders::_2, partitionsNeedCreate, topicSubResultPromise));
consumer->setPartitionIndex(i);
- consumers_.insert(std::make_pair(topicPartitionName, consumer));
+ consumers_.emplace(topicPartitionName, consumer);
LOG_DEBUG("Creating Consumer for - " << topicPartitionName << " - " << consumerStr_);
consumer->start();
}
@@ -232,20 +232,19 @@ void MultiTopicsConsumerImpl::unsubscribeAsync(ResultCallback callback) {
state_ = Closing;
lock.unlock();
- if (consumers_.empty()) {
+ std::shared_ptr<std::atomic<int>> consumerUnsubed = std::make_shared<std::atomic<int>>(0);
+ auto self = shared_from_this();
+ int numConsumers = 0;
+ consumers_.forEachValue(
+ [&numConsumers, &consumerUnsubed, &self, callback](const ConsumerImplPtr& consumer) {
+ numConsumers++;
+ consumer->unsubscribeAsync([self, consumerUnsubed, callback](Result result) {
+ self->handleUnsubscribedAsync(result, consumerUnsubed, callback);
+ });
+ });
+ if (numConsumers == 0) {
// No need to unsubscribe, since the list matching the regex was empty
callback(ResultOk);
- return;
- }
-
- std::shared_ptr<std::atomic<int>> consumerUnsubed = std::make_shared<std::atomic<int>>(0);
-
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- (consumer->second)
- ->unsubscribeAsync(std::bind(&MultiTopicsConsumerImpl::handleUnsubscribedAsync,
- shared_from_this(), std::placeholders::_1, consumerUnsubed,
- callback));
}
}
@@ -299,17 +298,17 @@ void MultiTopicsConsumerImpl::unsubscribeOneTopicAsync(const std::string& topic,
for (int i = 0; i < numberPartitions; i++) {
std::string topicPartitionName = topicName->getTopicPartitionName(i);
- std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName);
-
- if (consumers_.end() == iterator) {
+ auto optConsumer = consumers_.find(topicPartitionName);
+ if (optConsumer.is_empty()) {
LOG_ERROR("TopicsConsumer not subscribed on topicPartitionName: " << topicPartitionName);
callback(ResultUnknownError);
+ continue;
}
- (iterator->second)
- ->unsubscribeAsync(std::bind(&MultiTopicsConsumerImpl::handleOneTopicUnsubscribedAsync,
- shared_from_this(), std::placeholders::_1, consumerUnsubed,
- numberPartitions, topicName, topicPartitionName, callback));
+ optConsumer.value()->unsubscribeAsync(
+ std::bind(&MultiTopicsConsumerImpl::handleOneTopicUnsubscribedAsync, shared_from_this(),
+ std::placeholders::_1, consumerUnsubed, numberPartitions, topicName, topicPartitionName,
+ callback));
}
}
@@ -326,10 +325,9 @@ void MultiTopicsConsumerImpl::handleOneTopicUnsubscribedAsync(
LOG_DEBUG("Successfully Unsubscribed one Consumer. topicPartitionName - " << topicPartitionName);
- std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName);
- if (consumers_.end() != iterator) {
- iterator->second->pauseMessageListener();
- consumers_.erase(iterator);
+ auto optConsumer = consumers_.remove(topicPartitionName);
+ if (optConsumer.is_present()) {
+ optConsumer.value()->pauseMessageListener();
}
if (consumerUnsubed->load() == numberPartitions) {
@@ -363,7 +361,16 @@ void MultiTopicsConsumerImpl::closeAsync(ResultCallback callback) {
setState(Closing);
- if (consumers_.empty()) {
+ auto self = shared_from_this();
+ int numConsumers = 0;
+ consumers_.forEach(
+ [&numConsumers, &self, callback](const std::string& name, const ConsumerImplPtr& consumer) {
+ numConsumers++;
+ consumer->closeAsync([self, name, callback](Result result) {
+ self->handleSingleConsumerClose(result, name, callback);
+ });
+ });
+ if (numConsumers == 0) {
LOG_DEBUG("TopicsConsumer have no consumers to close "
<< " topic" << topic_ << " subscription - " << subscriptionName_);
setState(Closed);
@@ -373,27 +380,13 @@ void MultiTopicsConsumerImpl::closeAsync(ResultCallback callback) {
return;
}
- // close successfully subscribed consumers
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- std::string topicPartitionName = consumer->first;
- ConsumerImplPtr consumerPtr = consumer->second;
-
- consumerPtr->closeAsync(std::bind(&MultiTopicsConsumerImpl::handleSingleConsumerClose,
- shared_from_this(), std::placeholders::_1, topicPartitionName,
- callback));
- }
-
// fail pending recieve
failPendingReceiveCallback();
}
-void MultiTopicsConsumerImpl::handleSingleConsumerClose(Result result, std::string& topicPartitionName,
+void MultiTopicsConsumerImpl::handleSingleConsumerClose(Result result, std::string topicPartitionName,
CloseCallback callback) {
- std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName);
- if (consumers_.end() != iterator) {
- consumers_.erase(iterator);
- }
+ consumers_.remove(topicPartitionName);
LOG_DEBUG("Closing the consumer for partition - " << topicPartitionName << " numberTopicPartitions_ - "
<< numberTopicPartitions_->load());
@@ -543,15 +536,14 @@ void MultiTopicsConsumerImpl::acknowledgeAsync(const MessageId& msgId, ResultCal
}
const std::string& topicPartitionName = msgId.getTopicName();
- std::map<std::string, ConsumerImplPtr>::iterator iterator = consumers_.find(topicPartitionName);
+ auto optConsumer = consumers_.find(topicPartitionName);
- if (consumers_.end() != iterator) {
+ if (optConsumer.is_present()) {
unAckedMessageTrackerPtr_->remove(msgId);
- iterator->second->acknowledgeAsync(msgId, callback);
+ optConsumer.value()->acknowledgeAsync(msgId, callback);
} else {
LOG_ERROR("Message of topic: " << topicPartitionName << " not in unAckedMessageTracker");
callback(ResultUnknownError);
- return;
}
}
@@ -560,11 +552,11 @@ void MultiTopicsConsumerImpl::acknowledgeCumulativeAsync(const MessageId& msgId,
}
void MultiTopicsConsumerImpl::negativeAcknowledge(const MessageId& msgId) {
- auto iterator = consumers_.find(msgId.getTopicName());
+ auto optConsumer = consumers_.find(msgId.getTopicName());
- if (consumers_.end() != iterator) {
+ if (optConsumer.is_present()) {
unAckedMessageTrackerPtr_->remove(msgId);
- iterator->second->negativeAcknowledge(msgId);
+ optConsumer.value()->negativeAcknowledge(msgId);
}
}
@@ -605,22 +597,18 @@ bool MultiTopicsConsumerImpl::isOpen() {
}
void MultiTopicsConsumerImpl::receiveMessages() {
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- ConsumerImplPtr consumerPtr = consumer->second;
- consumerPtr->sendFlowPermitsToBroker(consumerPtr->getCnx().lock(), conf_.getReceiverQueueSize());
- LOG_DEBUG("Sending FLOW command for consumer - " << consumerPtr->getConsumerId());
- }
+ const auto receiverQueueSize = conf_.getReceiverQueueSize();
+ consumers_.forEachValue([receiverQueueSize](const ConsumerImplPtr& consumer) {
+ consumer->sendFlowPermitsToBroker(consumer->getCnx().lock(), receiverQueueSize);
+ LOG_DEBUG("Sending FLOW command for consumer - " << consumer->getConsumerId());
+ });
}
Result MultiTopicsConsumerImpl::pauseMessageListener() {
if (!messageListener_) {
return ResultInvalidConfiguration;
}
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- (consumer->second)->pauseMessageListener();
- }
+ consumers_.forEachValue([](const ConsumerImplPtr& consumer) { consumer->pauseMessageListener(); });
return ResultOk;
}
@@ -628,19 +616,14 @@ Result MultiTopicsConsumerImpl::resumeMessageListener() {
if (!messageListener_) {
return ResultInvalidConfiguration;
}
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- (consumer->second)->resumeMessageListener();
- }
+ consumers_.forEachValue([](const ConsumerImplPtr& consumer) { consumer->resumeMessageListener(); });
return ResultOk;
}
void MultiTopicsConsumerImpl::redeliverUnacknowledgedMessages() {
LOG_DEBUG("Sending RedeliverUnacknowledgedMessages command for partitioned consumer.");
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- (consumer->second)->redeliverUnacknowledgedMessages();
- }
+ consumers_.forEachValue(
+ [](const ConsumerImplPtr& consumer) { consumer->redeliverUnacknowledgedMessages(); });
unAckedMessageTrackerPtr_->clear();
}
@@ -653,10 +636,9 @@ void MultiTopicsConsumerImpl::redeliverUnacknowledgedMessages(const std::set<Mes
return;
}
LOG_DEBUG("Sending RedeliverUnacknowledgedMessages command for partitioned consumer.");
- for (ConsumerMap::const_iterator consumer = consumers_.begin(); consumer != consumers_.end();
- consumer++) {
- (consumer->second)->redeliverUnacknowledgedMessages(messageIds);
- }
+ consumers_.forEachValue([&messageIds](const ConsumerImplPtr& consumer) {
+ consumer->redeliverUnacknowledgedMessages(messageIds);
+ });
}
int MultiTopicsConsumerImpl::getNumOfPrefetchedMessages() const { return messages_.size(); }
@@ -671,15 +653,17 @@ void MultiTopicsConsumerImpl::getBrokerConsumerStatsAsync(BrokerConsumerStatsCal
MultiTopicsBrokerConsumerStatsPtr statsPtr =
std::make_shared<MultiTopicsBrokerConsumerStatsImpl>(numberTopicPartitions_->load());
LatchPtr latchPtr = std::make_shared<Latch>(numberTopicPartitions_->load());
- int size = consumers_.size();
lock.unlock();
- ConsumerMap::const_iterator consumer = consumers_.begin();
- for (int i = 0; i < size; i++, consumer++) {
- consumer->second->getBrokerConsumerStatsAsync(
- std::bind(&MultiTopicsConsumerImpl::handleGetConsumerStats, shared_from_this(),
- std::placeholders::_1, std::placeholders::_2, latchPtr, statsPtr, i, callback));
- }
+ auto self = shared_from_this();
+ size_t i = 0;
+ consumers_.forEachValue([&self, &latchPtr, &statsPtr, &i, callback](const ConsumerImplPtr& consumer) {
+ size_t index = i++;
+ consumer->getBrokerConsumerStatsAsync(
+ [self, latchPtr, statsPtr, index, callback](Result result, BrokerConsumerStats stats) {
+ self->handleGetConsumerStats(result, stats, latchPtr, statsPtr, index, callback);
+ });
+ });
}
void MultiTopicsConsumerImpl::handleGetConsumerStats(Result res, BrokerConsumerStats brokerConsumerStats,
@@ -725,10 +709,9 @@ void MultiTopicsConsumerImpl::seekAsync(uint64_t timestamp, ResultCallback callb
}
void MultiTopicsConsumerImpl::setNegativeAcknowledgeEnabledForTesting(bool enabled) {
- Lock lock(mutex_);
- for (auto&& c : consumers_) {
- c.second->setNegativeAcknowledgeEnabledForTesting(enabled);
- }
+ consumers_.forEachValue([enabled](const ConsumerImplPtr& consumer) {
+ consumer->setNegativeAcknowledgeEnabledForTesting(enabled);
+ });
}
bool MultiTopicsConsumerImpl::isConnected() const {
@@ -736,24 +719,19 @@ bool MultiTopicsConsumerImpl::isConnected() const {
if (state_ != Ready) {
return false;
}
+ lock.unlock();
- for (const auto& topicAndConsumer : consumers_) {
- if (!topicAndConsumer.second->isConnected()) {
- return false;
- }
- }
- return true;
+ return consumers_
+ .findFirstValueIf([](const ConsumerImplPtr& consumer) { return !consumer->isConnected(); })
+ .is_empty();
}
uint64_t MultiTopicsConsumerImpl::getNumberOfConnectedConsumer() {
- Lock lock(mutex_);
uint64_t numberOfConnectedConsumer = 0;
- const auto consumers = consumers_;
- lock.unlock();
- for (const auto& topicAndConsumer : consumers) {
- if (topicAndConsumer.second->isConnected()) {
+ consumers_.forEachValue([&numberOfConnectedConsumer](const ConsumerImplPtr& consumer) {
+ if (consumer->isConnected()) {
numberOfConnectedConsumer++;
}
- }
+ });
return numberOfConnectedConsumer;
}
diff --git a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h
index aa6b261..98b2f31 100644
--- a/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h
+++ b/pulsar-client-cpp/lib/MultiTopicsConsumerImpl.h
@@ -32,6 +32,7 @@
#include <lib/MultiTopicsBrokerConsumerStatsImpl.h>
#include <lib/TopicName.h>
#include <lib/NamespaceName.h>
+#include <lib/SynchronizedHashMap.h>
namespace pulsar {
typedef std::shared_ptr<Promise<Result, Consumer>> ConsumerSubResultPromisePtr;
@@ -93,7 +94,7 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase,
std::string consumerStr_;
std::string topic_;
const ConsumerConfiguration conf_;
- typedef std::map<std::string, ConsumerImplPtr> ConsumerMap;
+ typedef SynchronizedHashMap<std::string, ConsumerImplPtr> ConsumerMap;
ConsumerMap consumers_;
std::map<std::string, int> topicsPartitions_;
mutable std::mutex mutex_;
@@ -115,7 +116,7 @@ class MultiTopicsConsumerImpl : public ConsumerImplBase,
void handleSinglePartitionConsumerCreated(Result result, ConsumerImplBaseWeakPtr consumerImplBaseWeakPtr,
unsigned int partitionIndex);
- void handleSingleConsumerClose(Result result, std::string& topicPartitionName, CloseCallback callback);
+ void handleSingleConsumerClose(Result result, std::string topicPartitionName, CloseCallback callback);
void notifyResult(CloseCallback closeCallback);
void messageReceived(Consumer consumer, const Message& msg);
void internalListener(Consumer consumer);
diff --git a/pulsar-client-cpp/lib/SynchronizedHashMap.h b/pulsar-client-cpp/lib/SynchronizedHashMap.h
new file mode 100644
index 0000000..3a78467
--- /dev/null
+++ b/pulsar-client-cpp/lib/SynchronizedHashMap.h
@@ -0,0 +1,127 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#pragma once
+
+#include <functional>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "Utils.h"
+
+namespace pulsar {
+
+// Thread-safe wrapper around std::unordered_map<K, V>.
+//
+// V must be default constructible and copyable: lookups return copies of the stored values wrapped in
+// Optional<V>, so callers never hold references into the underlying map.
+//
+// Every public method (except the constructors) locks an internal std::recursive_mutex, so a callback
+// passed to forEach/forEachValue/findFirstValueIf may call back into the same map on the same thread,
+// including mutating calls such as remove(). To keep that safe, those traversals iterate over a snapshot
+// of the entries taken under the lock instead of over the live map: a re-entrant erase can therefore not
+// invalidate the iterator being used. The lock is still held for the whole traversal, so other threads
+// cannot modify the map until it finishes.
+template <typename K, typename V>
+class SynchronizedHashMap {
+    using MutexType = std::recursive_mutex;
+    using Lock = std::lock_guard<MutexType>;
+
+ public:
+    using OptValue = Optional<V>;
+    using PairVector = std::vector<std::pair<K, V>>;
+
+    SynchronizedHashMap() = default;
+
+    // Not locked: the object cannot be shared with other threads while it is being constructed.
+    SynchronizedHashMap(const PairVector& pairs) {
+        for (auto&& kv : pairs) {
+            data_.emplace(kv.first, kv.second);
+        }
+    }
+
+    // Forwards `args` to std::unordered_map::emplace; an entry whose key already exists is left unchanged.
+    template <typename... Args>
+    void emplace(Args&&... args) {
+        Lock lock(mutex_);
+        data_.emplace(std::forward<Args>(args)...);
+    }
+
+    // Calls f(key, value) for every entry while holding the lock. `f` may re-enter this map (even to
+    // remove entries) because the traversal runs over a snapshot of the entries.
+    void forEach(std::function<void(const K&, const V&)> f) const {
+        Lock lock(mutex_);
+        for (const auto& kv : snapshot()) {
+            f(kv.first, kv.second);
+        }
+    }
+
+    // Calls f(value) for every entry while holding the lock; same re-entrancy rules as forEach.
+    void forEachValue(std::function<void(const V&)> f) const {
+        Lock lock(mutex_);
+        for (const auto& kv : snapshot()) {
+            f(kv.second);
+        }
+    }
+
+    void clear() {
+        Lock lock(mutex_);
+        data_.clear();
+    }
+
+    // Returns a copy of the value mapped to `key`, or an empty Optional if the key is absent.
+    OptValue find(const K& key) const {
+        Lock lock(mutex_);
+        auto it = data_.find(key);
+        if (it != data_.end()) {
+            return OptValue::of(it->second);
+        } else {
+            return OptValue::empty();
+        }
+    }
+
+    // Returns a copy of the first value (in unspecified hash-map order) for which `f` returns true, or an
+    // empty Optional if there is none. `f` may re-enter this map, as in forEach.
+    OptValue findFirstValueIf(std::function<bool(const V&)> f) const {
+        Lock lock(mutex_);
+        for (const auto& kv : snapshot()) {
+            if (f(kv.second)) {
+                return OptValue::of(kv.second);
+            }
+        }
+        return OptValue::empty();
+    }
+
+    // Erases `key` and returns the value it was mapped to, or an empty Optional if the key was absent.
+    OptValue remove(const K& key) {
+        Lock lock(mutex_);
+        auto it = data_.find(key);
+        if (it != data_.end()) {
+            auto result = OptValue::of(it->second);
+            data_.erase(it);
+            return result;
+        } else {
+            return OptValue::empty();
+        }
+    }
+
+    // This method is only used for test
+    PairVector toPairVector() const {
+        Lock lock(mutex_);
+        return snapshot();
+    }
+
+    // This method is only used for test.
+    // Not noexcept: acquiring the mutex may throw std::system_error.
+    size_t size() const {
+        Lock lock(mutex_);
+        return data_.size();
+    }
+
+ private:
+    std::unordered_map<K, V> data_;
+    // Use recursive_mutex to allow methods being called in `forEach`
+    mutable MutexType mutex_;
+
+    // Copies every entry of data_. The caller must already hold mutex_.
+    PairVector snapshot() const {
+        PairVector pairs;
+        pairs.reserve(data_.size());
+        for (const auto& kv : data_) {
+            pairs.emplace_back(kv.first, kv.second);
+        }
+        return pairs;
+    }
+};
+
+} // namespace pulsar
diff --git a/pulsar-client-cpp/tests/ConsumerTest.cc b/pulsar-client-cpp/tests/ConsumerTest.cc
index 100086e..b61c15a 100644
--- a/pulsar-client-cpp/tests/ConsumerTest.cc
+++ b/pulsar-client-cpp/tests/ConsumerTest.cc
@@ -530,11 +530,14 @@ TEST(ConsumerTest, testMultiTopicsConsumerUnAckedMessageRedelivery) {
multiTopicsConsumerImplPtr->unAckedMessageTrackerPtr_.get());
ASSERT_EQ(numOfMessages * 3, multiTopicsTracker->size());
ASSERT_FALSE(multiTopicsTracker->isEmpty());
- for (auto iter = multiTopicsConsumerImplPtr->consumers_.begin();
- iter != multiTopicsConsumerImplPtr->consumers_.end(); ++iter) {
- auto subConsumerPtr = iter->second;
- auto tracker =
- static_cast<UnAckedMessageTrackerEnabled*>(subConsumerPtr->unAckedMessageTrackerPtr_.get());
+
+ std::vector<UnAckedMessageTrackerEnabled*> trackers;
+ multiTopicsConsumerImplPtr->consumers_.forEach(
+ [&trackers](const std::string& name, const ConsumerImplPtr& consumer) {
+ trackers.emplace_back(
+ static_cast<UnAckedMessageTrackerEnabled*>(consumer->unAckedMessageTrackerPtr_.get()));
+ });
+ for (const auto& tracker : trackers) {
ASSERT_EQ(0, tracker->size());
ASSERT_TRUE(tracker->isEmpty());
}
diff --git a/pulsar-client-cpp/tests/SynchronizedHashMapTest.cc b/pulsar-client-cpp/tests/SynchronizedHashMapTest.cc
new file mode 100644
index 0000000..62c55c4
--- /dev/null
+++ b/pulsar-client-cpp/tests/SynchronizedHashMapTest.cc
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <thread>
+#include <vector>
+#include "lib/Latch.h"
+#include "lib/SynchronizedHashMap.h"
+
+using namespace pulsar;
+using SyncMapType = SynchronizedHashMap<int, int>;
+using OptValue = SyncMapType::OptValue;
+using PairVector = SyncMapType::PairVector;
+
+// Blocks the calling thread for `millis` milliseconds.
+inline void sleepMs(long millis) {
+    const std::chrono::milliseconds duration{millis};
+    std::this_thread::sleep_for(duration);
+}
+
+// Returns a copy of `pairs` ordered by key, so the output of an unordered traversal can be compared
+// against a fixed expectation.
+inline PairVector sort(PairVector pairs) {
+    const auto byKey = [](const PairVector::value_type& a, const PairVector::value_type& b) {
+        return a.first < b.first;
+    };
+    std::sort(pairs.begin(), pairs.end(), byKey);
+    return pairs;
+}
+
+// All cases below belong to the `SynchronizedHashMapTest` suite.
+
+TEST(SynchronizedHashMapTest, testClear) {
+    SynchronizedHashMap<int, int> m({{1, 100}, {2, 200}});
+    m.clear();
+    ASSERT_EQ(m.toPairVector(), PairVector{});
+}
+
+TEST(SynchronizedHashMapTest, testRemoveAndFind) {
+    SyncMapType m({{1, 100}, {2, 200}, {3, 300}});
+
+    OptValue optValue;
+    optValue = m.findFirstValueIf([](const int& x) { return x == 200; });
+    ASSERT_TRUE(optValue.is_present());
+    ASSERT_EQ(optValue.value(), 200);
+
+    optValue = m.findFirstValueIf([](const int& x) { return x >= 301; });
+    ASSERT_FALSE(optValue.is_present());
+
+    optValue = m.find(1);
+    ASSERT_TRUE(optValue.is_present());
+    ASSERT_EQ(optValue.value(), 100);
+
+    ASSERT_FALSE(m.find(0).is_present());
+    ASSERT_FALSE(m.remove(0).is_present());
+
+    optValue = m.remove(1);
+    ASSERT_TRUE(optValue.is_present());
+    ASSERT_EQ(optValue.value(), 100);
+
+    ASSERT_FALSE(m.remove(1).is_present());
+    ASSERT_FALSE(m.find(1).is_present());
+}
+
+TEST(SynchronizedHashMapTest, testForEach) {
+    SyncMapType m({{1, 100}, {2, 200}, {3, 300}});
+    std::vector<int> values;
+    m.forEachValue([&values](const int& value) { values.emplace_back(value); });
+    std::sort(values.begin(), values.end());
+    ASSERT_EQ(values, std::vector<int>({100, 200, 300}));
+
+    PairVector pairs;
+    m.forEach([&pairs](const int& key, const int& value) { pairs.emplace_back(key, value); });
+    PairVector expectedPairs({{1, 100}, {2, 200}, {3, 300}});
+    ASSERT_EQ(sort(pairs), expectedPairs);
+}
+
+TEST(SynchronizedHashMapTest, testRecursiveMutex) {
+    SyncMapType m({{1, 100}});
+    OptValue optValue;
+    m.forEach([&m, &optValue](const int& key, const int&) {
+        optValue = m.find(key);  // the internal mutex was locked again
+    });
+    ASSERT_TRUE(optValue.is_present());
+    ASSERT_EQ(optValue.value(), 100);
+}
+
+TEST(SynchronizedHashMapTest, testThreadSafeForEach) {
+    SyncMapType m({{1, 100}, {2, 200}, {3, 300}});
+
+    Latch latch(1);
+    std::thread t{[&m, &latch] {
+        latch.wait();  // this thread must start after `m.forEach` started
+        m.remove(2);   // waits for the lock that `m.forEach` holds during its traversal
+    }};
+
+    std::atomic_bool firstElementDone{false};
+    PairVector pairs;
+    m.forEach([&latch, &firstElementDone, &pairs](const int& key, const int& value) {
+        pairs.emplace_back(key, value);
+        if (!firstElementDone) {
+            latch.countdown();
+            firstElementDone = true;
+        }
+        sleepMs(200);
+    });
+    // Join before any ASSERT_*: a failed assertion returns from this function, and destroying a joinable
+    // std::thread calls std::terminate, which would abort the whole test binary instead of reporting.
+    t.join();
+
+    // The concurrent remove(2) could not interleave with the traversal.
+    PairVector expectedPairsDuringForEach({{1, 100}, {2, 200}, {3, 300}});
+    ASSERT_EQ(sort(pairs), expectedPairsDuringForEach);
+
+    PairVector expectedPairsAfterRemove({{1, 100}, {3, 300}});
+    ASSERT_EQ(sort(m.toPairVector()), expectedPairsAfterRemove);
+}