This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new ee41959e Add the support of ZINTER and ZINTERCARD (#1992)
ee41959e is described below

commit ee41959e837a5a3023a9eb6aef4914567e9fdd64
Author: kay011 <[email protected]>
AuthorDate: Tue Jan 9 21:40:48 2024 +0800

    Add the support of ZINTER and ZINTERCARD (#1992)
---
 src/commands/cmd_zset.cc                 | 83 ++++++++++++++++++++++++++++++++
 src/types/redis_zset.cc                  | 46 ++++++++++++++++++
 src/types/redis_zset.h                   |  1 +
 tests/gocase/unit/type/zset/zset_test.go | 52 ++++++++++++++++++++
 4 files changed, 182 insertions(+)

diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index ab4909f4..1fa51ab2 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -1252,6 +1252,87 @@ class CommandZInterStore : public CommandZUnionStore {
   }
 };
 
+class CommandZInter : public CommandZUnion {
+ public:
+  CommandZInter() : CommandZUnion() {}
+
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+    std::vector<MemberScore> member_scores;
+    auto s = zset_db.Inter(keys_weights_, aggregate_method_, &member_scores);
+    if (!s.ok()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+    auto ms_comparator = [](const MemberScore &ms1, const MemberScore &ms2) {
+      if (ms1.score == ms2.score) {
+        return ms1.member < ms2.member;
+      }
+      return ms1.score < ms2.score;
+    };
+    std::sort(member_scores.begin(), member_scores.end(), ms_comparator);
+    output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 2 : 
1)));
+    for (const auto &member_score : member_scores) {
+      output->append(redis::BulkString(member_score.member));
+      if (with_scores_) 
output->append(redis::BulkString(util::Float2String(member_score.score)));
+    }
+    return Status::OK();
+  }
+
+  static CommandKeyRange Range(const std::vector<std::string> &args) {
+    int num_key = *ParseInt<int>(args[1], 10);
+    return {2, 1 + num_key, 1};
+  }
+};
+
+class CommandZInterCard : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    CommandParser parser(args, 1);
+    numkeys_ = GET_OR_RET(parser.TakeInt<int>(NumericRange<int>{1, 
std::numeric_limits<int>::max()}));
+    for (size_t i = 0; i < numkeys_; ++i) {
+      keys_.emplace_back(GET_OR_RET(parser.TakeStr()));
+    }
+
+    // if set limit option
+    if (parser.Good()) {
+      if (parser.EatEqICase("limit")) {
+        auto res = parser.TakeInt<int64_t>();
+        if (!res.IsOK() || res.GetValue() < 0) {
+          return {Status::RedisParseErr, errLimitIsNegative};
+        }
+        limit_ = static_cast<size_t>(res.GetValue());
+        if (parser.Good()) {
+          return parser.InvalidSyntax();
+        }
+      } else {
+        return parser.InvalidSyntax();
+      }
+    }
+
+    return Commander::Parse(args);
+  }
+  Status Execute(Server *srv, Connection *conn, std::string *output) override {
+    redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+    uint64_t count = 0;
+    auto s = zset_db.InterCard(keys_, limit_, &count);
+    if (!s.ok()) {
+      return {Status::RedisExecErr, s.ToString()};
+    }
+    *output = redis::Integer(count);
+    return Status::OK();
+  }
+
+  static CommandKeyRange Range(const std::vector<std::string> &args) {
+    int num_key = *ParseInt<int>(args[1], 10);
+    return {2, 1 + num_key, 1};
+  }
+
+ private:
+  size_t numkeys_ = 0;
+  size_t limit_ = 0;
+  std::vector<std::string> keys_;
+};
+
 class CommandZScan : public CommandSubkeyScanBase {
  public:
   CommandZScan() = default;
@@ -1281,6 +1362,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", 
-4, "write", 1, 1, 1),
                         MakeCmdAttr<CommandZCount>("zcount", 4, "read-only", 
1, 1, 1),
                         MakeCmdAttr<CommandZIncrBy>("zincrby", 4, "write", 1, 
1, 1),
                         MakeCmdAttr<CommandZInterStore>("zinterstore", -4, 
"write", CommandZInterStore::Range),
+                        MakeCmdAttr<CommandZInter>("zinter", -3, "read-only", 
CommandZInter::Range),
+                        MakeCmdAttr<CommandZInterCard>("zintercard", -3, 
"read-only", CommandZInterCard::Range),
                         MakeCmdAttr<CommandZLexCount>("zlexcount", 4, 
"read-only", 1, 1, 1),
                         MakeCmdAttr<CommandZPopMax>("zpopmax", -2, "write", 1, 
1, 1),
                         MakeCmdAttr<CommandZPopMin>("zpopmin", -2, "write", 1, 
1, 1),
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index 92bf5be7..f2976544 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -699,6 +699,52 @@ rocksdb::Status ZSet::Inter(const std::vector<KeyWeight> 
&keys_weights, Aggregat
   return rocksdb::Status::OK();
 }
 
+rocksdb::Status ZSet::InterCard(const std::vector<std::string> &user_keys, 
uint64_t limit, uint64_t *inter_cnt) {
+  std::vector<std::string> lock_keys;
+  lock_keys.reserve(user_keys.size());
+  for (const auto &user_key : user_keys) {
+    std::string ns_key = AppendNamespacePrefix(user_key);
+    lock_keys.emplace_back(std::move(ns_key));
+  }
+  MultiLockGuard guard(storage_->GetLockManager(), lock_keys);
+
+  std::vector<MemberScores> mscores_list;
+  mscores_list.reserve(user_keys.size());
+  RangeScoreSpec spec;
+  for (const auto &user_key : user_keys) {
+    MemberScores mscores;
+    auto s = RangeByScore(user_key, spec, &mscores, nullptr);
+    if (!s.ok() || mscores.empty()) return s;
+    mscores_list.emplace_back(mscores);
+  }
+  std::sort(mscores_list.begin(), mscores_list.end(),
+            [](const MemberScores &v1, const MemberScores &v2) { return 
v1.size() < v2.size(); });
+
+  auto base_mscores = mscores_list[0];
+  std::map<std::string, size_t> member_counters;
+  uint64_t cardinality = 0;
+  for (const auto &base_ms : base_mscores) {
+    member_counters[base_ms.member] = 1;
+    for (size_t i = 1; i < mscores_list.size(); i++) {
+      for (const auto &ms : mscores_list[i]) {
+        if (base_ms.member == ms.member) {
+          member_counters[ms.member]++;
+          break;
+        }
+      }
+    }
+    if (member_counters[base_ms.member] == mscores_list.size()) {
+      cardinality++;
+      if (limit > 0 && cardinality >= limit) {
+        *inter_cnt = limit;
+        return rocksdb::Status::OK();
+      };
+    }
+  }
+  *inter_cnt = cardinality;
+  return rocksdb::Status::OK();
+}
+
 rocksdb::Status ZSet::UnionStore(const Slice &dst, const 
std::vector<KeyWeight> &keys_weights,
                                  AggregateMethod aggregate_method, uint64_t 
*saved_cnt) {
   *saved_cnt = 0;
diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h
index 9768d010..3cd81622 100644
--- a/src/types/redis_zset.h
+++ b/src/types/redis_zset.h
@@ -111,6 +111,7 @@ class ZSet : public SubKeyScanner {
                              AggregateMethod aggregate_method, uint64_t 
*saved_cnt);
   rocksdb::Status Inter(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
                         std::vector<MemberScore> *members);
+  rocksdb::Status InterCard(const std::vector<std::string> &user_keys, 
uint64_t limit, uint64_t *inter_cnt);
   rocksdb::Status UnionStore(const Slice &dst, const std::vector<KeyWeight> 
&keys_weights,
                              AggregateMethod aggregate_method, uint64_t 
*saved_cnt);
   rocksdb::Status Union(const std::vector<KeyWeight> &keys_weights, 
AggregateMethod aggregate_method,
diff --git a/tests/gocase/unit/type/zset/zset_test.go 
b/tests/gocase/unit/type/zset/zset_test.go
index d1590425..860316b2 100644
--- a/tests/gocase/unit/type/zset/zset_test.go
+++ b/tests/gocase/unit/type/zset/zset_test.go
@@ -1236,6 +1236,58 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding s
                require.Equal(t, []redis.Z{{2, "b"}, {3, "c"}}, 
rdb.ZRangeWithScores(ctx, "zsetc", 0, -1).Val())
        })
 
+       t.Run(fmt.Sprintf("ZINTER with AGGREGATE and WEIGHTS - %s", encoding), 
func(t *testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 3, Member: "d"},
+               })
+
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}}).Val())
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Aggregate: "max"}).Val())
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Aggregate: "min"}).Val())
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Aggregate: "sum"}).Val())
+
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Weights: []float64{2, 3}, 
Aggregate: "sum"}).Val())
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Weights: []float64{2, 3}, 
Aggregate: "max"}).Val())
+               require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Weights: []float64{2, 3}, 
Aggregate: "min"}).Val())
+
+               require.Equal(t, []redis.Z{{3, "b"}, {5, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", 
"zsetb"}}).Val())
+               require.Equal(t, []redis.Z{{2, "b"}, {3, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"}, 
Aggregate: "max"}).Val())
+               require.Equal(t, []redis.Z{{1, "b"}, {2, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"}, 
Aggregate: "min"}).Val())
+               require.Equal(t, []redis.Z{{3, "b"}, {5, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"}, 
Aggregate: "sum"}).Val())
+
+               require.Equal(t, []redis.Z{{7, "b"}, {12, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"}, 
Weights: []float64{2, 3}, Aggregate: "sum"}).Val())
+               require.Equal(t, []redis.Z{{4, "b"}, {6, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"}, 
Weights: []float64{2, 3}, Aggregate: "max"}).Val())
+               require.Equal(t, []redis.Z{{3, "b"}, {6, "c"}}, 
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"}, 
Weights: []float64{2, 3}, Aggregate: "min"}).Val())
+
+               require.Equal(t, 0, len(rdb.ZInter(ctx, &redis.ZStore{Keys: 
[]string{"zseta", "zsetb", "zset_noexists"}}).Val()))
+               require.Equal(t, 0, len(rdb.ZInterWithScores(ctx, 
&redis.ZStore{Keys: []string{"zseta", "zsetb", "zset_noexists"}, Weights: 
[]float64{2, 3}, Aggregate: "sum"}).Val()))
+       })
+
+       t.Run(fmt.Sprintf("ZINTERCARD with LIMIT - %s", encoding), func(t 
*testing.T) {
+               createZset(rdb, ctx, "zseta", []redis.Z{
+                       {Score: 1, Member: "a"},
+                       {Score: 2, Member: "b"},
+                       {Score: 3, Member: "c"},
+               })
+               createZset(rdb, ctx, "zsetb", []redis.Z{
+                       {Score: 1, Member: "b"},
+                       {Score: 2, Member: "c"},
+                       {Score: 3, Member: "d"},
+               })
+               require.Equal(t, int64(2), rdb.ZInterCard(ctx, 0, "zseta", 
"zsetb").Val())
+               require.Equal(t, int64(1), rdb.ZInterCard(ctx, 1, "zseta", 
"zsetb").Val())
+
+               require.Error(t, rdb.ZInterCard(ctx, -1, "zseta", 
"zsetb").Err())
+
+       })
+
        for i, cmd := range []func(ctx context.Context, dest string, store 
*redis.ZStore) *redis.IntCmd{rdb.ZInterStore, rdb.ZUnionStore} {
                var funcName string
                switch i {

Reply via email to