This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new ee41959e Add the support of ZINTER and ZINTERCARD (#1992)
ee41959e is described below
commit ee41959e837a5a3023a9eb6aef4914567e9fdd64
Author: kay011 <[email protected]>
AuthorDate: Tue Jan 9 21:40:48 2024 +0800
Add the support of ZINTER and ZINTERCARD (#1992)
---
src/commands/cmd_zset.cc | 83 ++++++++++++++++++++++++++++++++
src/types/redis_zset.cc | 46 ++++++++++++++++++
src/types/redis_zset.h | 1 +
tests/gocase/unit/type/zset/zset_test.go | 52 ++++++++++++++++++++
4 files changed, 182 insertions(+)
diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index ab4909f4..1fa51ab2 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -1252,6 +1252,87 @@ class CommandZInterStore : public CommandZUnionStore {
}
};
+class CommandZInter : public CommandZUnion {
+ public:
+ CommandZInter() : CommandZUnion() {}
+
+ Status Execute(Server *srv, Connection *conn, std::string *output) override {
+ redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+ std::vector<MemberScore> member_scores;
+ auto s = zset_db.Inter(keys_weights_, aggregate_method_, &member_scores);
+ if (!s.ok()) {
+ return {Status::RedisExecErr, s.ToString()};
+ }
+ auto ms_comparator = [](const MemberScore &ms1, const MemberScore &ms2) {
+ if (ms1.score == ms2.score) {
+ return ms1.member < ms2.member;
+ }
+ return ms1.score < ms2.score;
+ };
+ std::sort(member_scores.begin(), member_scores.end(), ms_comparator);
+ output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 2 :
1)));
+ for (const auto &member_score : member_scores) {
+ output->append(redis::BulkString(member_score.member));
+ if (with_scores_)
output->append(redis::BulkString(util::Float2String(member_score.score)));
+ }
+ return Status::OK();
+ }
+
+ static CommandKeyRange Range(const std::vector<std::string> &args) {
+ int num_key = *ParseInt<int>(args[1], 10);
+ return {2, 1 + num_key, 1};
+ }
+};
+
+class CommandZInterCard : public Commander {
+ public:
+ Status Parse(const std::vector<std::string> &args) override {
+ CommandParser parser(args, 1);
+ numkeys_ = GET_OR_RET(parser.TakeInt<int>(NumericRange<int>{1,
std::numeric_limits<int>::max()}));
+ for (size_t i = 0; i < numkeys_; ++i) {
+ keys_.emplace_back(GET_OR_RET(parser.TakeStr()));
+ }
+
+ // if set limit option
+ if (parser.Good()) {
+ if (parser.EatEqICase("limit")) {
+ auto res = parser.TakeInt<int64_t>();
+ if (!res.IsOK() || res.GetValue() < 0) {
+ return {Status::RedisParseErr, errLimitIsNegative};
+ }
+ limit_ = static_cast<size_t>(res.GetValue());
+ if (parser.Good()) {
+ return parser.InvalidSyntax();
+ }
+ } else {
+ return parser.InvalidSyntax();
+ }
+ }
+
+ return Commander::Parse(args);
+ }
+ Status Execute(Server *srv, Connection *conn, std::string *output) override {
+ redis::ZSet zset_db(srv->storage, conn->GetNamespace());
+ uint64_t count = 0;
+ auto s = zset_db.InterCard(keys_, limit_, &count);
+ if (!s.ok()) {
+ return {Status::RedisExecErr, s.ToString()};
+ }
+ *output = redis::Integer(count);
+ return Status::OK();
+ }
+
+ static CommandKeyRange Range(const std::vector<std::string> &args) {
+ int num_key = *ParseInt<int>(args[1], 10);
+ return {2, 1 + num_key, 1};
+ }
+
+ private:
+ size_t numkeys_ = 0;
+ size_t limit_ = 0;
+ std::vector<std::string> keys_;
+};
+
class CommandZScan : public CommandSubkeyScanBase {
public:
CommandZScan() = default;
@@ -1281,6 +1362,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd",
-4, "write", 1, 1, 1),
MakeCmdAttr<CommandZCount>("zcount", 4, "read-only",
1, 1, 1),
MakeCmdAttr<CommandZIncrBy>("zincrby", 4, "write", 1,
1, 1),
MakeCmdAttr<CommandZInterStore>("zinterstore", -4,
"write", CommandZInterStore::Range),
+ MakeCmdAttr<CommandZInter>("zinter", -3, "read-only",
CommandZInter::Range),
+ MakeCmdAttr<CommandZInterCard>("zintercard", -3,
"read-only", CommandZInterCard::Range),
MakeCmdAttr<CommandZLexCount>("zlexcount", 4,
"read-only", 1, 1, 1),
MakeCmdAttr<CommandZPopMax>("zpopmax", -2, "write", 1,
1, 1),
MakeCmdAttr<CommandZPopMin>("zpopmin", -2, "write", 1,
1, 1),
diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc
index 92bf5be7..f2976544 100644
--- a/src/types/redis_zset.cc
+++ b/src/types/redis_zset.cc
@@ -699,6 +699,52 @@ rocksdb::Status ZSet::Inter(const std::vector<KeyWeight>
&keys_weights, Aggregat
return rocksdb::Status::OK();
}
+rocksdb::Status ZSet::InterCard(const std::vector<std::string> &user_keys,
uint64_t limit, uint64_t *inter_cnt) {
+ std::vector<std::string> lock_keys;
+ lock_keys.reserve(user_keys.size());
+ for (const auto &user_key : user_keys) {
+ std::string ns_key = AppendNamespacePrefix(user_key);
+ lock_keys.emplace_back(std::move(ns_key));
+ }
+ MultiLockGuard guard(storage_->GetLockManager(), lock_keys);
+
+ std::vector<MemberScores> mscores_list;
+ mscores_list.reserve(user_keys.size());
+ RangeScoreSpec spec;
+ for (const auto &user_key : user_keys) {
+ MemberScores mscores;
+ auto s = RangeByScore(user_key, spec, &mscores, nullptr);
+ if (!s.ok() || mscores.empty()) return s;
+ mscores_list.emplace_back(mscores);
+ }
+ std::sort(mscores_list.begin(), mscores_list.end(),
+ [](const MemberScores &v1, const MemberScores &v2) { return
v1.size() < v2.size(); });
+
+ auto base_mscores = mscores_list[0];
+ std::map<std::string, size_t> member_counters;
+ uint64_t cardinality = 0;
+ for (const auto &base_ms : base_mscores) {
+ member_counters[base_ms.member] = 1;
+ for (size_t i = 1; i < mscores_list.size(); i++) {
+ for (const auto &ms : mscores_list[i]) {
+ if (base_ms.member == ms.member) {
+ member_counters[ms.member]++;
+ break;
+ }
+ }
+ }
+ if (member_counters[base_ms.member] == mscores_list.size()) {
+ cardinality++;
+ if (limit > 0 && cardinality >= limit) {
+ *inter_cnt = limit;
+ return rocksdb::Status::OK();
+ };
+ }
+ }
+ *inter_cnt = cardinality;
+ return rocksdb::Status::OK();
+}
+
rocksdb::Status ZSet::UnionStore(const Slice &dst, const
std::vector<KeyWeight> &keys_weights,
AggregateMethod aggregate_method, uint64_t
*saved_cnt) {
*saved_cnt = 0;
diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h
index 9768d010..3cd81622 100644
--- a/src/types/redis_zset.h
+++ b/src/types/redis_zset.h
@@ -111,6 +111,7 @@ class ZSet : public SubKeyScanner {
AggregateMethod aggregate_method, uint64_t
*saved_cnt);
rocksdb::Status Inter(const std::vector<KeyWeight> &keys_weights,
AggregateMethod aggregate_method,
std::vector<MemberScore> *members);
+ rocksdb::Status InterCard(const std::vector<std::string> &user_keys,
uint64_t limit, uint64_t *inter_cnt);
rocksdb::Status UnionStore(const Slice &dst, const std::vector<KeyWeight>
&keys_weights,
AggregateMethod aggregate_method, uint64_t
*saved_cnt);
rocksdb::Status Union(const std::vector<KeyWeight> &keys_weights,
AggregateMethod aggregate_method,
diff --git a/tests/gocase/unit/type/zset/zset_test.go
b/tests/gocase/unit/type/zset/zset_test.go
index d1590425..860316b2 100644
--- a/tests/gocase/unit/type/zset/zset_test.go
+++ b/tests/gocase/unit/type/zset/zset_test.go
@@ -1236,6 +1236,58 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx
context.Context, encoding s
require.Equal(t, []redis.Z{{2, "b"}, {3, "c"}},
rdb.ZRangeWithScores(ctx, "zsetc", 0, -1).Val())
})
+ t.Run(fmt.Sprintf("ZINTER with AGGREGATE and WEIGHTS - %s", encoding),
func(t *testing.T) {
+ createZset(rdb, ctx, "zseta", []redis.Z{
+ {Score: 1, Member: "a"},
+ {Score: 2, Member: "b"},
+ {Score: 3, Member: "c"},
+ })
+ createZset(rdb, ctx, "zsetb", []redis.Z{
+ {Score: 1, Member: "b"},
+ {Score: 2, Member: "c"},
+ {Score: 3, Member: "d"},
+ })
+
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}}).Val())
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Aggregate: "max"}).Val())
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Aggregate: "min"}).Val())
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Aggregate: "sum"}).Val())
+
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Weights: []float64{2, 3},
Aggregate: "sum"}).Val())
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Weights: []float64{2, 3},
Aggregate: "max"}).Val())
+ require.Equal(t, []string{"b", "c"}, rdb.ZInter(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb"}, Weights: []float64{2, 3},
Aggregate: "min"}).Val())
+
+ require.Equal(t, []redis.Z{{3, "b"}, {5, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta",
"zsetb"}}).Val())
+ require.Equal(t, []redis.Z{{2, "b"}, {3, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"},
Aggregate: "max"}).Val())
+ require.Equal(t, []redis.Z{{1, "b"}, {2, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"},
Aggregate: "min"}).Val())
+ require.Equal(t, []redis.Z{{3, "b"}, {5, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"},
Aggregate: "sum"}).Val())
+
+ require.Equal(t, []redis.Z{{7, "b"}, {12, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"},
Weights: []float64{2, 3}, Aggregate: "sum"}).Val())
+ require.Equal(t, []redis.Z{{4, "b"}, {6, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"},
Weights: []float64{2, 3}, Aggregate: "max"}).Val())
+ require.Equal(t, []redis.Z{{3, "b"}, {6, "c"}},
rdb.ZInterWithScores(ctx, &redis.ZStore{Keys: []string{"zseta", "zsetb"},
Weights: []float64{2, 3}, Aggregate: "min"}).Val())
+
+ require.Equal(t, 0, len(rdb.ZInter(ctx, &redis.ZStore{Keys:
[]string{"zseta", "zsetb", "zset_noexists"}}).Val()))
+ require.Equal(t, 0, len(rdb.ZInterWithScores(ctx,
&redis.ZStore{Keys: []string{"zseta", "zsetb", "zset_noexists"}, Weights:
[]float64{2, 3}, Aggregate: "sum"}).Val()))
+ })
+
+ t.Run(fmt.Sprintf("ZINTERCARD with LIMIT - %s", encoding), func(t
*testing.T) {
+ createZset(rdb, ctx, "zseta", []redis.Z{
+ {Score: 1, Member: "a"},
+ {Score: 2, Member: "b"},
+ {Score: 3, Member: "c"},
+ })
+ createZset(rdb, ctx, "zsetb", []redis.Z{
+ {Score: 1, Member: "b"},
+ {Score: 2, Member: "c"},
+ {Score: 3, Member: "d"},
+ })
+ require.Equal(t, int64(2), rdb.ZInterCard(ctx, 0, "zseta",
"zsetb").Val())
+ require.Equal(t, int64(1), rdb.ZInterCard(ctx, 1, "zseta",
"zsetb").Val())
+
+ require.Error(t, rdb.ZInterCard(ctx, -1, "zseta",
"zsetb").Err())
+
+ })
+
for i, cmd := range []func(ctx context.Context, dest string, store
*redis.ZStore) *redis.IntCmd{rdb.ZInterStore, rdb.ZUnionStore} {
var funcName string
switch i {