This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 3ea7f4fe Add the support of RESP3 double type (#2053)
3ea7f4fe is described below

commit 3ea7f4fece7e5af38d276b3cc5037e53f1d7965f
Author: hulk <[email protected]>
AuthorDate: Fri Jan 26 13:09:53 2024 +0800

    Add the support of RESP3 double type (#2053)
---
 src/commands/cmd_geo.cc                     | 19 ++++------
 src/commands/cmd_hash.cc                    |  9 ++++-
 src/commands/cmd_server.cc                  |  4 +-
 src/commands/cmd_zset.cc                    | 51 +++++++++++++------------
 src/commands/scan_base.h                    | 22 ++---------
 src/server/redis_connection.h               |  3 ++
 tests/gocase/unit/debug/debug_test.go       |  2 +
 tests/gocase/unit/geo/geo_test.go           | 14 ++++++-
 tests/gocase/unit/protocol/protocol_test.go |  2 +
 tests/gocase/unit/type/hash/hash_test.go    | 40 +++++++++----------
 tests/gocase/unit/type/zset/zset_test.go    | 59 +++++++++++++++++++----------
 11 files changed, 125 insertions(+), 100 deletions(-)

diff --git a/src/commands/cmd_geo.cc b/src/commands/cmd_geo.cc
index 0f4d98eb..2342e3df 100644
--- a/src/commands/cmd_geo.cc
+++ b/src/commands/cmd_geo.cc
@@ -150,7 +150,7 @@ class CommandGeoDist : public CommandGeoBase {
     if (s.IsNotFound()) {
       *output = conn->NilString();
     } else {
-      *output = 
redis::BulkString(util::Float2String(GetDistanceByUnit(distance)));
+      *output = conn->Double(GetDistanceByUnit(distance));
     }
     return Status::OK();
   }
@@ -215,8 +215,7 @@ class CommandGeoPos : public Commander {
       if (iter == geo_points.end()) {
         list.emplace_back(conn->NilString());
       } else {
-        list.emplace_back(conn->MultiBulkString(
-            {util::Float2String(iter->second.longitude), 
util::Float2String(iter->second.latitude)}));
+        list.emplace_back(redis::Array({conn->Double(iter->second.longitude), 
conn->Double(iter->second.latitude)}));
       }
     }
     *output = redis::Array(list);
@@ -331,14 +330,13 @@ class CommandGeoRadius : public CommandGeoBase {
         std::vector<std::string> one;
         one.emplace_back(redis::BulkString(geo_point.member));
         if (with_dist_) {
-          
one.emplace_back(redis::BulkString(util::Float2String(GetDistanceByUnit(geo_point.dist))));
+          one.emplace_back(conn->Double(GetDistanceByUnit(geo_point.dist)));
         }
         if (with_hash_) {
-          
one.emplace_back(redis::BulkString(util::Float2String(geo_point.score)));
+          one.emplace_back(conn->Double(geo_point.score));
         }
         if (with_coord_) {
-          one.emplace_back(
-              conn->MultiBulkString({util::Float2String(geo_point.longitude), 
util::Float2String(geo_point.latitude)}));
+          one.emplace_back(redis::Array({conn->Double(geo_point.longitude), 
conn->Double(geo_point.latitude)}));
         }
         list.emplace_back(redis::Array(one));
       }
@@ -509,14 +507,13 @@ class CommandGeoSearch : public CommandGeoBase {
         std::vector<std::string> one;
         one.emplace_back(redis::BulkString(geo_point.member));
         if (with_dist_) {
-          
one.emplace_back(redis::BulkString(util::Float2String(GetDistanceByUnit(geo_point.dist))));
+          one.emplace_back(conn->Double(GetDistanceByUnit(geo_point.dist)));
         }
         if (with_hash_) {
-          
one.emplace_back(redis::BulkString(util::Float2String(geo_point.score)));
+          one.emplace_back(conn->Double(geo_point.score));
         }
         if (with_coord_) {
-          one.emplace_back(
-              conn->MultiBulkString({util::Float2String(geo_point.longitude), 
util::Float2String(geo_point.latitude)}));
+          one.emplace_back(redis::Array({conn->Double(geo_point.longitude), 
conn->Double(geo_point.latitude)}));
         }
         output.emplace_back(redis::Array(one));
       }
diff --git a/src/commands/cmd_hash.cc b/src/commands/cmd_hash.cc
index c62aabd5..cc4c475e 100644
--- a/src/commands/cmd_hash.cc
+++ b/src/commands/cmd_hash.cc
@@ -372,7 +372,14 @@ class CommandHScan : public CommandSubkeyScanBase {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    *output = GenerateOutput(srv, conn, fields, values, CursorType::kTypeHash);
+    auto cursor = GetNextCursor(srv, fields, CursorType::kTypeHash);
+    std::vector<std::string> entries;
+    entries.reserve(2 * fields.size());
+    for (size_t i = 0; i < fields.size(); i++) {
+      entries.emplace_back(redis::BulkString(fields[i]));
+      entries.emplace_back(redis::BulkString(values[i]));
+    }
+    *output = redis::Array({redis::BulkString(cursor), redis::Array(entries)});
     return Status::OK();
   }
 };
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index 45a6aeaf..10cbc8cd 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -607,6 +607,8 @@ class CommandDebug : public Commander {
         *output = redis::BulkString("Hello World");
       } else if (protocol_type_ == "integer") {
         *output = redis::Integer(12345);
+      } else if (protocol_type_ == "double") {
+        *output = conn->Double(3.141);
       } else if (protocol_type_ == "array") {
         *output = redis::MultiLen(3);
         for (int i = 0; i < 3; i++) {
@@ -634,7 +636,7 @@ class CommandDebug : public Commander {
       } else {
         *output = redis::Error(
             "Wrong protocol type name. Please use one of the following: "
-            "string|integer|array|set|bignum|true|false|null");
+            "string|integer|double|array|set|bignum|true|false|null");
       }
     } else {
       return {Status::RedisInvalidCmd, "Unknown subcommand, should be DEBUG or 
PROTOCOL"};
diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc
index 12927fa6..17ec5548 100644
--- a/src/commands/cmd_zset.cc
+++ b/src/commands/cmd_zset.cc
@@ -87,7 +87,7 @@ class CommandZAdd : public Commander {
         return Status::OK();
       }
 
-      *output = redis::BulkString(util::Float2String(new_score));
+      *output = conn->Double(new_score);
     } else {
       *output = redis::Integer(ret);
     }
@@ -192,7 +192,7 @@ class CommandZIncrBy : public Commander {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    *output = redis::BulkString(util::Float2String(score));
+    *output = conn->Double(score);
     return Status::OK();
   }
 
@@ -258,7 +258,7 @@ class CommandZPop : public Commander {
     output->append(redis::MultiLen(member_scores.size() * 2));
     for (const auto &ms : member_scores) {
       output->append(redis::BulkString(ms.member));
-      output->append(redis::BulkString(util::Float2String(ms.score)));
+      output->append(conn->Double(ms.score));
     }
 
     return Status::OK();
@@ -329,7 +329,7 @@ class CommandBZPop : public BlockingCommander {
     }
 
     if (!member_scores.empty()) {
-      SendMembersWithScores(member_scores, user_key);
+      SendMembersWithScores(conn, member_scores, user_key);
       return Status::OK();
     }
 
@@ -350,13 +350,14 @@ class CommandBZPop : public BlockingCommander {
     }
   }
 
-  void SendMembersWithScores(const std::vector<MemberScore> &member_scores, 
const std::string &user_key) {
+  void SendMembersWithScores(const Connection *conn, const 
std::vector<MemberScore> &member_scores,
+                             const std::string &user_key) {
     std::string output;
     output.append(redis::MultiLen(member_scores.size() * 2 + 1));
     output.append(redis::BulkString(user_key));
     for (const auto &ms : member_scores) {
       output.append(redis::BulkString(ms.member));
-      output.append(redis::BulkString(util::Float2String(ms.score)));
+      output.append(conn->Double(ms.score));
     }
     conn_->Reply(output);
   }
@@ -374,7 +375,7 @@ class CommandBZPop : public BlockingCommander {
 
     bool empty = member_scores.empty();
     if (!empty) {
-      SendMembersWithScores(member_scores, user_key);
+      SendMembersWithScores(conn_, member_scores, user_key);
     }
 
     return !empty;
@@ -405,7 +406,7 @@ static void SendMembersWithScoresForZMpop(Connection *conn, 
const std::string &u
   output.append(redis::MultiLen(member_scores.size() * 2));
   for (const auto &ms : member_scores) {
     output.append(redis::BulkString(ms.member));
-    output.append(redis::BulkString(util::Float2String(ms.score)));
+    output.append(conn->Double(ms.score));
   }
   conn->Reply(output);
 }
@@ -817,7 +818,7 @@ class CommandZRangeGeneric : public Commander {
     output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 2 : 
1)));
     for (const auto &ms : member_scores) {
       output->append(redis::BulkString(ms.member));
-      if (with_scores_) 
output->append(redis::BulkString(util::Float2String(ms.score)));
+      if (with_scores_) output->append(conn->Double(ms.score));
     }
     return Status::OK();
   }
@@ -904,7 +905,7 @@ class CommandZRank : public Commander {
       if (with_score_) {
         output->append(redis::MultiLen(2));
         output->append(redis::Integer(rank));
-        output->append(redis::BulkString(util::Float2String(score)));
+        output->append(conn->Double(score));
       } else {
         *output = redis::Integer(rank);
       }
@@ -1047,7 +1048,7 @@ class CommandZScore : public Commander {
     if (s.IsNotFound()) {
       *output = conn->NilString();
     } else {
-      *output = redis::BulkString(util::Float2String(score));
+      *output = conn->Double(score);
     }
     return Status::OK();
   }
@@ -1074,9 +1075,9 @@ class CommandZMScore : public Commander {
       for (const auto &member : members) {
         auto iter = mscores.find(member.ToString());
         if (iter == mscores.end()) {
-          values.emplace_back("");
+          values.emplace_back(conn->NilString());
         } else {
-          values.emplace_back(util::Float2String(iter->second));
+          values.emplace_back(conn->Double(iter->second));
         }
       }
     }
@@ -1142,7 +1143,7 @@ class CommandZUnion : public Commander {
     output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 2 : 
1)));
     for (const auto &ms : member_scores) {
       output->append(redis::BulkString(ms.member));
-      if (with_scores_) 
output->append(redis::BulkString(util::Float2String(ms.score)));
+      if (with_scores_) output->append(conn->Double(ms.score));
     }
     return Status::OK();
   }
@@ -1276,7 +1277,7 @@ class CommandZInter : public CommandZUnion {
     output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 2 : 
1)));
     for (const auto &member_score : member_scores) {
       output->append(redis::BulkString(member_score.member));
-      if (with_scores_) 
output->append(redis::BulkString(util::Float2String(member_score.score)));
+      if (with_scores_) output->append(conn->Double(member_score.score));
     }
     return Status::OK();
   }
@@ -1350,12 +1351,14 @@ class CommandZScan : public CommandSubkeyScanBase {
       return {Status::RedisExecErr, s.ToString()};
     }
 
-    std::vector<std::string> score_strings;
-    score_strings.reserve(scores.size());
-    for (const auto &score : scores) {
-      score_strings.emplace_back(util::Float2String(score));
+    auto cursor = GetNextCursor(srv, members, CursorType::kTypeZSet);
+    std::vector<std::string> entries;
+    entries.reserve(2 * members.size());
+    for (size_t i = 0; i < members.size(); i++) {
+      entries.emplace_back(redis::BulkString(members[i]));
+      entries.emplace_back(conn->Double(scores[i]));
     }
-    *output = GenerateOutput(srv, conn, members, score_strings, 
CursorType::kTypeZSet);
+    *output = redis::Array({redis::BulkString(cursor), redis::Array(entries)});
     return Status::OK();
   }
 };
@@ -1402,14 +1405,14 @@ class CommandZRandMember : public Commander {
     result_entries.reserve(member_scores.size());
 
     for (const auto &[member, score] : member_scores) {
-      result_entries.emplace_back(member);
-      if (with_scores_) result_entries.emplace_back(util::Float2String(score));
+      result_entries.emplace_back(BulkString(member));
+      if (with_scores_) result_entries.emplace_back(conn->Double(score));
     }
 
     if (no_parameters_)
       *output = s.IsNotFound() ? conn->NilString() : 
redis::BulkString(result_entries[0]);
     else
-      *output = ArrayOfBulkStrings(result_entries);
+      *output = Array(result_entries);
     return Status::OK();
   }
 
@@ -1456,7 +1459,7 @@ class CommandZDiff : public Commander {
     output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ 
? 2 : 1)));
     for (const auto &ms : members_with_scores) {
       output->append(redis::BulkString(ms.member));
-      if (with_scores_) 
output->append(redis::BulkString(util::Float2String(ms.score)));
+      if (with_scores_) output->append(conn->Double(ms.score));
     }
 
     return Status::OK();
diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h
index 5fc345c9..2e11c989 100644
--- a/src/commands/scan_base.h
+++ b/src/commands/scan_base.h
@@ -112,25 +112,11 @@ class CommandSubkeyScanBase : public CommandScanBase {
     return Commander::Parse(args);
   }
 
-  std::string GenerateOutput(Server *srv, const Connection *conn, const 
std::vector<std::string> &fields,
-                             const std::vector<std::string> &values, 
CursorType cursor_type) {
-    std::vector<std::string> list;
-    auto items_count = fields.size();
-    if (items_count == static_cast<size_t>(limit_)) {
-      auto end_cursor = srv->GenerateCursorFromKeyName(fields.back(), 
cursor_type);
-      list.emplace_back(redis::BulkString(end_cursor));
-    } else {
-      list.emplace_back(redis::BulkString("0"));
+  std::string GetNextCursor(Server *srv, std::vector<std::string> &fields, 
CursorType cursor_type) const {
+    if (fields.size() == static_cast<size_t>(limit_)) {
+      return srv->GenerateCursorFromKeyName(fields.back(), cursor_type);
     }
-    std::vector<std::string> fvs;
-    if (items_count > 0) {
-      for (size_t i = 0; i < items_count; i++) {
-        fvs.emplace_back(fields[i]);
-        fvs.emplace_back(values[i]);
-      }
-    }
-    list.emplace_back(ArrayOfBulkStrings(fvs));
-    return redis::Array(list);
+    return "0";
   }
 
  protected:
diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h
index d05b2e53..e51f66ab 100644
--- a/src/server/redis_connection.h
+++ b/src/server/redis_connection.h
@@ -68,6 +68,9 @@ class Connection : public EvbufCallbackBase<Connection> {
   std::string BigNumber(const std::string &n) const {
     return protocol_version_ == RESP::v3 ? "(" + n + CRLF : BulkString(n);
   }
+  std::string Double(double d) const {
+    return protocol_version_ == RESP::v3 ? "," + util::Float2String(d) + CRLF 
: BulkString(util::Float2String(d));
+  }
   std::string NilString() const { return redis::NilString(protocol_version_); }
   std::string NilArray() const { return protocol_version_ == RESP::v3 ? "_" 
CRLF : "*-1" CRLF; }
   std::string MultiBulkString(const std::vector<std::string> &values) const;
diff --git a/tests/gocase/unit/debug/debug_test.go 
b/tests/gocase/unit/debug/debug_test.go
index 1aae088f..416e4a7d 100644
--- a/tests/gocase/unit/debug/debug_test.go
+++ b/tests/gocase/unit/debug/debug_test.go
@@ -44,6 +44,7 @@ func TestDebugProtocolV2(t *testing.T) {
                types := map[string]interface{}{
                        "string":  "Hello World",
                        "integer": int64(12345),
+                       "double":  "3.141",
                        "array":   []interface{}{int64(0), int64(1), int64(2)},
                        "set":     []interface{}{int64(0), int64(1), int64(2)},
                        "map":     []interface{}{int64(0), int64(0), int64(1), 
int64(1), int64(2), int64(0)},
@@ -89,6 +90,7 @@ func TestDebugProtocolV3(t *testing.T) {
                types := map[string]interface{}{
                        "string":  "Hello World",
                        "integer": int64(12345),
+                       "double":  3.141,
                        "array":   []interface{}{int64(0), int64(1), int64(2)},
                        "set":     []interface{}{int64(0), int64(1), int64(2)},
                        "map":     map[interface{}]interface{}{int64(0): false, 
int64(1): true, int64(2): false},
diff --git a/tests/gocase/unit/geo/geo_test.go 
b/tests/gocase/unit/geo/geo_test.go
index 6db8c222..ebbbcd32 100644
--- a/tests/gocase/unit/geo/geo_test.go
+++ b/tests/gocase/unit/geo/geo_test.go
@@ -86,8 +86,18 @@ func compareLists(list1, list2 []string) []string {
        return result
 }
 
-func TestGeo(t *testing.T) {
-       srv := util.StartServer(t, map[string]string{})
+func TestGeoWithRESP2(t *testing.T) {
+       testGeo(t, "no")
+}
+
+func TestGeoWithRESP3(t *testing.T) {
+       testGeo(t, "yes")
+}
+
+var testGeo = func(t *testing.T, enabledRESP3 string) {
+       srv := util.StartServer(t, map[string]string{
+               "resp3-enabled": enabledRESP3,
+       })
        defer srv.Close()
        ctx := context.Background()
        rdb := srv.NewClient()
diff --git a/tests/gocase/unit/protocol/protocol_test.go 
b/tests/gocase/unit/protocol/protocol_test.go
index 33202c90..61db7cf1 100644
--- a/tests/gocase/unit/protocol/protocol_test.go
+++ b/tests/gocase/unit/protocol/protocol_test.go
@@ -153,6 +153,7 @@ func TestProtocolRESP2(t *testing.T) {
                types := map[string][]string{
                        "string":  {"$11", "Hello World"},
                        "integer": {":12345"},
+                       "double":  {"$5", "3.141"},
                        "array":   {"*3", ":0", ":1", ":2"},
                        "set":     {"*3", ":0", ":1", ":2"},
                        "map":     {"*6", ":0", ":0", ":1", ":1", ":2", ":0"},
@@ -208,6 +209,7 @@ func TestProtocolRESP3(t *testing.T) {
                types := map[string][]string{
                        "string":  {"$11", "Hello World"},
                        "integer": {":12345"},
+                       "double":  {",3.141"},
                        "array":   {"*3", ":0", ":1", ":2"},
                        "set":     {"~3", ":0", ":1", ":2"},
                        "map":     {"%3", ":0", "#f", ":1", "#t", ":2", "#f"},
diff --git a/tests/gocase/unit/type/hash/hash_test.go 
b/tests/gocase/unit/type/hash/hash_test.go
index e6c8beba..38b0a576 100644
--- a/tests/gocase/unit/type/hash/hash_test.go
+++ b/tests/gocase/unit/type/hash/hash_test.go
@@ -50,8 +50,18 @@ func getVals(hash map[string]string) []string {
        return r
 }
 
-func TestHash(t *testing.T) {
-       srv := util.StartServer(t, map[string]string{})
+func TestHashWithRESP2(t *testing.T) {
+       testHash(t, "no")
+}
+
+func TestHashWithRESP3(t *testing.T) {
+       testHash(t, "yes")
+}
+
+var testHash = func(t *testing.T, enabledRESP3 string) {
+       srv := util.StartServer(t, map[string]string{
+               "resp3-enabled": enabledRESP3,
+       })
        defer srv.Close()
        ctx := context.Background()
        rdb := srv.NewClient()
@@ -359,29 +369,15 @@ func TestHash(t *testing.T) {
        })
 
        t.Run("HGETALL - small hash}", func(t *testing.T) {
-               res := rdb.Do(ctx, "hgetall", "smallhash").Val().([]interface{})
-               mid := make(map[string]string)
-               for i := 0; i < len(res); i += 2 {
-                       if res[i+1] == nil {
-                               mid[res[i].(string)] = ""
-                       } else {
-                               mid[res[i].(string)] = res[i+1].(string)
-                       }
-               }
-               require.Equal(t, smallhash, mid)
+               gotHash, err := rdb.HGetAll(ctx, "smallhash").Result()
+               require.NoError(t, err)
+               require.Equal(t, smallhash, gotHash)
        })
 
        t.Run("HGETALL - big hash}", func(t *testing.T) {
-               res := rdb.Do(ctx, "hgetall", "bighash").Val().([]interface{})
-               mid := make(map[string]string)
-               for i := 0; i < len(res); i += 2 {
-                       if res[i+1] == nil {
-                               mid[res[i].(string)] = ""
-                       } else {
-                               mid[res[i].(string)] = res[i+1].(string)
-                       }
-               }
-               require.Equal(t, bighash, mid)
+               gotHash, err := rdb.HGetAll(ctx, "bighash").Result()
+               require.NoError(t, err)
+               require.Equal(t, bighash, gotHash)
        })
 
        t.Run("HGETALL - field with empty string as a value", func(t 
*testing.T) {
diff --git a/tests/gocase/unit/type/zset/zset_test.go 
b/tests/gocase/unit/type/zset/zset_test.go
index d7bc434e..5f1cf80f 100644
--- a/tests/gocase/unit/type/zset/zset_test.go
+++ b/tests/gocase/unit/type/zset/zset_test.go
@@ -67,7 +67,8 @@ func createDefaultLexZset(rdb *redis.Client, ctx 
context.Context) {
                {0, "omega"}})
 }
 
-func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding 
string, srv *util.KvrocksServer) {
+func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, 
enabledRESP3, encoding string, srv *util.KvrocksServer) {
+       isRESP3 := enabledRESP3 == "yes"
        t.Run(fmt.Sprintf("Check encoding - %s", encoding), func(t *testing.T) {
                rdb.Del(ctx, "ztmp")
                rdb.ZAdd(ctx, "ztmp", redis.Z{Score: 10, Member: "x"})
@@ -103,9 +104,15 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding s
 
        t.Run(fmt.Sprintf("ZSET ZADD IncrMixedOtherOptions - %s", encoding), 
func(t *testing.T) {
                rdb.Del(ctx, "ztmp")
-               require.Equal(t, "1.5", rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", 
"nx", "nx", "incr", "1.5", "abc").Val())
-               require.Equal(t, redis.Nil, rdb.Do(ctx, "zadd", "ztmp", "nx", 
"nx", "nx", "nx", "incr", "1.5", "abc").Err())
-               require.Equal(t, "3", rdb.Do(ctx, "zadd", "ztmp", "xx", "xx", 
"xx", "xx", "incr", "1.5", "abc").Val())
+               if isRESP3 {
+                       require.Equal(t, 1.5, rdb.Do(ctx, "zadd", "ztmp", "nx", 
"nx", "nx", "nx", "incr", "1.5", "abc").Val())
+                       require.Equal(t, redis.Nil, rdb.Do(ctx, "zadd", "ztmp", 
"nx", "nx", "nx", "nx", "incr", "1.5", "abc").Err())
+                       require.EqualValues(t, 3, rdb.Do(ctx, "zadd", "ztmp", 
"xx", "xx", "xx", "xx", "incr", "1.5", "abc").Val())
+               } else {
+                       require.Equal(t, "1.5", rdb.Do(ctx, "zadd", "ztmp", 
"nx", "nx", "nx", "nx", "incr", "1.5", "abc").Val())
+                       require.Equal(t, redis.Nil, rdb.Do(ctx, "zadd", "ztmp", 
"nx", "nx", "nx", "nx", "incr", "1.5", "abc").Err())
+                       require.Equal(t, "3", rdb.Do(ctx, "zadd", "ztmp", "xx", 
"xx", "xx", "xx", "incr", "1.5", "abc").Val())
+               }
 
                rdb.Del(ctx, "ztmp")
                require.Equal(t, 1.5, rdb.ZAddArgsIncr(ctx, "ztmp", 
redis.ZAddArgs{NX: true, Members: []redis.Z{{Member: "abc", Score: 
1.5}}}).Val())
@@ -684,14 +691,14 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding s
                require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", 
"z").Val())
                require.Equal(t, redis.Nil, rdb.ZRevRank(ctx, "zranktmp", 
"foo").Err())
 
-               require.Equal(t, []interface{}{int64(0), "10"}, rdb.Do(ctx, 
"zrank", "zranktmp", "x", "withscore").Val())
-               require.Equal(t, []interface{}{int64(1), "20"}, rdb.Do(ctx, 
"zrank", "zranktmp", "y", "withscore").Val())
-               require.Equal(t, []interface{}{int64(2), "30"}, rdb.Do(ctx, 
"zrank", "zranktmp", "z", "withscore").Val())
-               require.Equal(t, redis.Nil, rdb.Do(ctx, "zrank", "zranktmp", 
"foo", "withscore").Err())
-               require.Equal(t, []interface{}{int64(2), "10"}, rdb.Do(ctx, 
"zrevrank", "zranktmp", "x", "withscore").Val())
-               require.Equal(t, []interface{}{int64(1), "20"}, rdb.Do(ctx, 
"zrevrank", "zranktmp", "y", "withscore").Val())
-               require.Equal(t, []interface{}{int64(0), "30"}, rdb.Do(ctx, 
"zrevrank", "zranktmp", "z", "withscore").Val())
-               require.Equal(t, redis.Nil, rdb.Do(ctx, "zrevrank", "zranktmp", 
"foo", "withscore").Err())
+               require.Equal(t, redis.RankScore{Rank: 0, Score: 10}, 
rdb.ZRankWithScore(ctx, "zranktmp", "x").Val())
+               require.Equal(t, redis.RankScore{Rank: 1, Score: 20}, 
rdb.ZRankWithScore(ctx, "zranktmp", "y").Val())
+               require.Equal(t, redis.RankScore{Rank: 2, Score: 30}, 
rdb.ZRankWithScore(ctx, "zranktmp", "z").Val())
+               require.Equal(t, redis.Nil, rdb.ZRankWithScore(ctx, "zranktmp", 
"foo").Err())
+               require.Equal(t, redis.RankScore{Rank: 2, Score: 10}, 
rdb.ZRevRankWithScore(ctx, "zranktmp", "x").Val())
+               require.Equal(t, redis.RankScore{Rank: 1, Score: 20}, 
rdb.ZRevRankWithScore(ctx, "zranktmp", "y").Val())
+               require.Equal(t, redis.RankScore{Rank: 0, Score: 30}, 
rdb.ZRevRankWithScore(ctx, "zranktmp", "z").Val())
+               require.Equal(t, redis.Nil, rdb.ZRevRankWithScore(ctx, 
"zranktmp", "foo").Err())
        })
 
        t.Run(fmt.Sprintf("ZRANK/ZREVRANK - after deletion -%s", encoding), 
func(t *testing.T) {
@@ -704,12 +711,12 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding s
                require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", 
"z").Val())
                require.Equal(t, redis.Nil, rdb.ZRevRank(ctx, "zranktmp", 
"foo").Err())
 
-               require.Equal(t, []interface{}{int64(0), "10"}, rdb.Do(ctx, 
"zrank", "zranktmp", "x", "withscore").Val())
-               require.Equal(t, []interface{}{int64(1), "30"}, rdb.Do(ctx, 
"zrank", "zranktmp", "z", "withscore").Val())
-               require.Equal(t, redis.Nil, rdb.Do(ctx, "zrank", "zranktmp", 
"foo", "withscore").Err())
-               require.Equal(t, []interface{}{int64(1), "10"}, rdb.Do(ctx, 
"zrevrank", "zranktmp", "x", "withscore").Val())
-               require.Equal(t, []interface{}{int64(0), "30"}, rdb.Do(ctx, 
"zrevrank", "zranktmp", "z", "withscore").Val())
-               require.Equal(t, redis.Nil, rdb.Do(ctx, "zrevrank", "zranktmp", 
"foo", "withscore").Err())
+               require.Equal(t, redis.RankScore{Rank: 0, Score: 10}, 
rdb.ZRankWithScore(ctx, "zranktmp", "x").Val())
+               require.Equal(t, redis.RankScore{Rank: 1, Score: 30}, 
rdb.ZRankWithScore(ctx, "zranktmp", "z").Val())
+               require.Equal(t, redis.Nil, rdb.ZRankWithScore(ctx, "zranktmp", 
"foo").Err())
+               require.Equal(t, redis.RankScore{Rank: 1, Score: 10}, 
rdb.ZRevRankWithScore(ctx, "zranktmp", "x").Val())
+               require.Equal(t, redis.RankScore{Rank: 0, Score: 30}, 
rdb.ZRevRankWithScore(ctx, "zranktmp", "z").Val())
+               require.Equal(t, redis.Nil, rdb.ZRevRankWithScore(ctx, 
"zranktmp", "foo").Err())
        })
 
        t.Run(fmt.Sprintf("ZINCRBY - can create a new sorted set - %s", 
encoding), func(t *testing.T) {
@@ -1893,14 +1900,24 @@ func stressTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding
        })
 }
 
-func TestZset(t *testing.T) {
-       srv := util.StartServer(t, map[string]string{})
+func TestZSetWithRESP2(t *testing.T) {
+       testZSet(t, "no")
+}
+
+func TestZSetWithRESP3(t *testing.T) {
+       testZSet(t, "yes")
+}
+
+var testZSet = func(t *testing.T, enabledRESP3 string) {
+       srv := util.StartServer(t, map[string]string{
+               "resp3-enabled": enabledRESP3,
+       })
        defer srv.Close()
        ctx := context.Background()
        rdb := srv.NewClient()
        defer func() { require.NoError(t, rdb.Close()) }()
 
-       basicTests(t, rdb, ctx, "skiplist", srv)
+       basicTests(t, rdb, ctx, enabledRESP3, "skiplist", srv)
 
        t.Run("ZUNIONSTORE regression, should not create NaN in scores", func(t 
*testing.T) {
                rdb.ZAdd(ctx, "z", redis.Z{Score: math.Inf(-1), Member: 
"neginf"})

Reply via email to