This is an automated email from the ASF dual-hosted git repository.

binbin pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new aaff696e Split Inter in ZSet::InterStore into a separate function 
(#1726)
aaff696e is described below

commit aaff696e76b6a7ff5921667136aa241da47b39eb
Author: Binbin <[email protected]>
AuthorDate: Sun Sep 3 18:12:50 2023 +0800

    Split Inter in ZSet::InterStore into a separate function (#1726)
    
    CI TSAN shows that ZINTERSTORE may deadlock after locks were
    introduced to DEL in #1712. In ZSet::InterStore, if the dst key is
    also in the source key list, we may deadlock because the
    Overwrite function also locks the dst key.
    
    In this PR, we split Inter in ZSet::InterStore into a separate
    function, just like the Set apis.
    
    With this PR applied, the CI verification for #1712 now
    passes. Closes #1715
    
    This PR also cleans up saved_cnt, since it is the same as members.size().
---
 src/commands/cmd_zset.cc |  2 +-
 src/types/redis_zset.cc  | 27 +++++++++++++++------------
 src/types/redis_zset.h   |  4 +++-
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index 70439e68..52c9a3d4 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -1228,7 +1228,7 @@ class CommandZUnion : public Commander {
   Status Execute(Server *svr, Connection *conn, std::string *output) override {
     redis::ZSet zset_db(svr->storage, conn->GetNamespace());
     std::vector<MemberScore> member_scores;
-    auto s = zset_db.Union(keys_weights_, aggregate_method_, nullptr, 
&member_scores);
+    auto s = zset_db.Union(keys_weights_, aggregate_method_, &member_scores);
     if (!s.ok()) {
       return {Status::RedisExecErr, s.ToString()};
     }
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index 09495151..c99d3c8a 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -626,6 +626,16 @@ rocksdb::Status ZSet::Overwrite(const Slice &user_key, 
const MemberScores &mscor
 
 rocksdb::Status ZSet::InterStore(const Slice &dst, const 
std::vector<KeyWeight> &keys_weights,
                                  AggregateMethod aggregate_method, uint64_t 
*saved_cnt) {
+  *saved_cnt = 0;
+  std::vector<MemberScore> members;
+  auto s = Inter(keys_weights, aggregate_method, &members);
+  if (!s.ok()) return s;
+  *saved_cnt = members.size();
+  return Overwrite(dst, members);
+}
+
+rocksdb::Status ZSet::Inter(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
+                            std::vector<MemberScore> *members) {
   std::vector<std::string> lock_keys;
   lock_keys.reserve(keys_weights.size());
   for (const auto &key_weight : keys_weights) {
@@ -634,8 +644,6 @@ rocksdb::Status ZSet::InterStore(const Slice &dst, const 
std::vector<KeyWeight>
   }
   MultiLockGuard guard(storage_->GetLockManager(), lock_keys);
 
-  if (saved_cnt) *saved_cnt = 0;
-
   std::map<std::string, double> dst_zset;
   std::map<std::string, size_t> member_counters;
   std::vector<MemberScore> target_mscores;
@@ -680,14 +688,12 @@ rocksdb::Status ZSet::InterStore(const Slice &dst, const 
std::vector<KeyWeight>
       }
     }
   }
-  if (!dst_zset.empty()) {
-    std::vector<MemberScore> mscores;
+  if (members && !dst_zset.empty()) {
+    members->reserve(dst_zset.size());
     for (const auto &iter : dst_zset) {
       if (member_counters[iter.first] != keys_weights.size()) continue;
-      mscores.emplace_back(MemberScore{iter.first, iter.second});
+      members->emplace_back(MemberScore{iter.first, iter.second});
     }
-    if (saved_cnt) *saved_cnt = mscores.size();
-    Overwrite(dst, mscores);
   }
 
   return rocksdb::Status::OK();
@@ -697,14 +703,14 @@ rocksdb::Status ZSet::UnionStore(const Slice &dst, const 
std::vector<KeyWeight>
                                  AggregateMethod aggregate_method, uint64_t 
*saved_cnt) {
   *saved_cnt = 0;
   std::vector<MemberScore> members;
-  auto s = Union(keys_weights, aggregate_method, saved_cnt, &members);
+  auto s = Union(keys_weights, aggregate_method, &members);
   if (!s.ok()) return s;
   *saved_cnt = members.size();
   return Overwrite(dst, members);
 }
 
 rocksdb::Status ZSet::Union(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
-                            uint64_t *saved_cnt, std::vector<MemberScore> 
*members) {
+                            std::vector<MemberScore> *members) {
   std::vector<std::string> lock_keys;
   lock_keys.reserve(keys_weights.size());
   for (const auto &key_weight : keys_weights) {
@@ -713,8 +719,6 @@ rocksdb::Status ZSet::Union(const std::vector<KeyWeight> 
&keys_weights, Aggregat
   }
   MultiLockGuard guard(storage_->GetLockManager(), lock_keys);
 
-  if (saved_cnt) *saved_cnt = 0;
-
   std::map<std::string, double> dst_zset;
   std::vector<MemberScore> target_mscores;
   uint64_t target_size = 0;
@@ -753,7 +757,6 @@ rocksdb::Status ZSet::Union(const std::vector<KeyWeight> 
&keys_weights, Aggregat
     for (const auto &iter : dst_zset) {
       members->emplace_back(MemberScore{iter.first, iter.second});
     }
-    if (saved_cnt) *saved_cnt = members->size();
   }
   return rocksdb::Status::OK();
 }
diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h
index 1ea0b192..9768d010 100644
--- a/src/types/redis_zset.h
+++ b/src/types/redis_zset.h
@@ -109,10 +109,12 @@ class ZSet : public SubKeyScanner {
   rocksdb::Status Overwrite(const Slice &user_key, const MemberScores 
&mscores);
   rocksdb::Status InterStore(const Slice &dst, const std::vector<KeyWeight> 
&keys_weights,
                              AggregateMethod aggregate_method, uint64_t 
*saved_cnt);
+  rocksdb::Status Inter(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
+                        std::vector<MemberScore> *members);
   rocksdb::Status UnionStore(const Slice &dst, const std::vector<KeyWeight> 
&keys_weights,
                              AggregateMethod aggregate_method, uint64_t 
*saved_cnt);
   rocksdb::Status Union(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
-                        uint64_t *saved_cnt, std::vector<MemberScore> 
*members);
+                        std::vector<MemberScore> *members);
   rocksdb::Status MGet(const Slice &user_key, const std::vector<Slice> 
&members, std::map<std::string, double> *scores);
   rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata);
 

Reply via email to