This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 6e79fa94  Implement the RESP3 boolean type (#1991)
6e79fa94 is described below

commit 6e79fa940edca2ce80d863d92f967d5ade767c6d
Author: hulk <[email protected]>
AuthorDate: Mon Jan 8 10:10:10 2024 +0800

     Implement the RESP3 boolean type (#1991)
    
    In addition to the RESP3 boolean type, this PR implements the
    `DEBUG PROTOCOL <type>` command to test the new types. It also
    adds a new configuration option, `resp3-enabled`, to enable RESP3
    when testing.
---
 kvrocks.conf                                       |   6 ++
 src/commands/cmd_server.cc                         |  40 +++++++-
 src/config/config.cc                               |   1 +
 src/config/config.h                                |   1 +
 src/server/redis_connection.h                      |   6 +-
 src/server/redis_reply.h                           |   9 ++
 src/storage/scripting.cc                           |  14 ++-
 src/storage/scripting.h                            |   2 +-
 tests/gocase/go.mod                                |   2 +-
 tests/gocase/go.sum                                |   8 +-
 .../integration/slotmigrate/slotmigrate_test.go    |   2 +-
 tests/gocase/unit/debug/debug_test.go              | 106 +++++++++++++++++++++
 tests/gocase/unit/hello/hello_test.go              |  21 ++++
 tests/gocase/unit/type/zset/zset_test.go           |   4 +-
 14 files changed, 202 insertions(+), 20 deletions(-)

diff --git a/kvrocks.conf b/kvrocks.conf
index e5460187..7d099073 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -311,6 +311,12 @@ max-bitmap-to-string-mb 16
 # Default: no
 redis-cursor-compatible no
 
+# Whether to enable the RESP3 protocol.
+# NOTICE: RESP3 is still under development; don't enable it in a production environment.
+#
+# Default: no
+# resp3-enabled no
+
 # Maximum nesting depth allowed when parsing and serializing 
 # JSON documents while using JSON commands like JSON.SET.
 # Default: 1024
diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc
index c06f13f4..a41094d1 100644
--- a/src/commands/cmd_server.cc
+++ b/src/commands/cmd_server.cc
@@ -591,20 +591,44 @@ class CommandDebug : public Commander {
 
       microsecond_ = static_cast<uint64_t>(*second * 1000 * 1000);
       return Status::OK();
+    } else if (subcommand_ == "protocol" && args.size() == 3) {
+      protocol_type_ = util::ToLower(args[2]);
+      return Status::OK();
     }
-    return {Status::RedisInvalidCmd, "Syntax error, DEBUG SLEEP <seconds>"};
+    return {Status::RedisInvalidCmd, "Syntax error, DEBUG SLEEP 
<seconds>|PROTOCOL <type>"};
   }
 
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
     if (subcommand_ == "sleep") {
       usleep(microsecond_);
+      *output = redis::SimpleString("OK");
+    } else if (subcommand_ == "protocol") {  // protocol type
+      if (protocol_type_ == "string") {
+        *output = redis::BulkString("Hello World");
+      } else if (protocol_type_ == "integer") {
+        *output = redis::Integer(12345);
+      } else if (protocol_type_ == "array") {
+        *output = redis::MultiLen(3);
+        for (int i = 0; i < 3; i++) {
+          *output += redis::Integer(i);
+        }
+      } else if (protocol_type_ == "true") {
+        *output = redis::Bool(conn->GetProtocolVersion(), true);
+      } else if (protocol_type_ == "false") {
+        *output = redis::Bool(conn->GetProtocolVersion(), false);
+      } else {
+        *output =
+            redis::Error("Wrong protocol type name. Please use one of the 
following: string|int|array|true|false");
+      }
+    } else {
+      return {Status::RedisInvalidCmd, "Unknown subcommand, should be DEBUG or 
PROTOCOL"};
     }
-    *output = redis::SimpleString("OK");
     return Status::OK();
   }
 
  private:
   std::string subcommand_;
+  std::string protocol_type_;
   uint64_t microsecond_ = 0;
 };
 
@@ -685,14 +709,15 @@ class CommandHello final : public Commander {
  public:
   Status Execute(Server *srv, Connection *conn, std::string *output) override {
     size_t next_arg = 1;
+    int protocol = 2;  // default protocol version is 2
     if (args_.size() >= 2) {
-      auto parse_result = ParseInt<int64_t>(args_[next_arg], 10);
+      auto parse_result = ParseInt<int>(args_[next_arg], 10);
       ++next_arg;
       if (!parse_result) {
         return {Status::NotOK, "Protocol version is not an integer or out of 
range"};
       }
 
-      int64_t protocol = *parse_result;
+      protocol = *parse_result;
 
       // In redis, it will check protocol < 2 or protocol > 3,
       // kvrocks only supports REPL2 by now, but for supporting some
@@ -737,7 +762,12 @@ class CommandHello final : public Commander {
     output_list.push_back(redis::BulkString("server"));
     output_list.push_back(redis::BulkString("redis"));
     output_list.push_back(redis::BulkString("proto"));
-    output_list.push_back(redis::Integer(2));
+    if (srv->GetConfig()->resp3_enabled) {
+      output_list.push_back(redis::Integer(protocol));
+      conn->SetProtocolVersion(protocol == 3 ? RESP::v3 : RESP::v2);
+    } else {
+      output_list.push_back(redis::Integer(2));
+    }
 
     output_list.push_back(redis::BulkString("mode"));
     // Note: sentinel is not supported in kvrocks.
diff --git a/src/config/config.cc b/src/config/config.cc
index c71103b7..5ea8c617 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -164,6 +164,7 @@ Config::Config() {
       {"log-retention-days", false, new IntField(&log_retention_days, -1, -1, 
INT_MAX)},
       {"persist-cluster-nodes-enabled", false, new 
YesNoField(&persist_cluster_nodes_enabled, true)},
       {"redis-cursor-compatible", false, new 
YesNoField(&redis_cursor_compatible, false)},
+      {"resp3-enabled", false, new YesNoField(&resp3_enabled, false)},
       {"repl-namespace-enabled", false, new 
YesNoField(&repl_namespace_enabled, false)},
       {"json-max-nesting-depth", false, new IntField(&json_max_nesting_depth, 
1024, 0, INT_MAX)},
       {"json-storage-format", false,
diff --git a/src/config/config.h b/src/config/config.h
index 244f1469..46e260bc 100644
--- a/src/config/config.h
+++ b/src/config/config.h
@@ -147,6 +147,7 @@ struct Config {
   int sequence_gap;
 
   bool redis_cursor_compatible = false;
+  bool resp3_enabled = false;
   int log_retention_days;
 
   // load_tokens is used to buffer the tokens when loading,
diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h
index e38a70fa..25b522d8 100644
--- a/src/server/redis_connection.h
+++ b/src/server/redis_connection.h
@@ -58,10 +58,13 @@ class Connection : public EvbufCallbackBase<Connection> {
   void OnRead(bufferevent *bev);
   void OnWrite(bufferevent *bev);
   void OnEvent(bufferevent *bev, int16_t events);
-  void Reply(const std::string &msg);
   void SendFile(int fd);
   std::string ToString();
 
+  void Reply(const std::string &msg);
+  RESP GetProtocolVersion() const { return protocol_version_; }
+  void SetProtocolVersion(RESP version) { protocol_version_ = version; }
+
   using UnsubscribeCallback = std::function<void(std::string, int)>;
   void SubscribeChannel(const std::string &channel);
   void UnsubscribeChannel(const std::string &channel);
@@ -164,6 +167,7 @@ class Connection : public EvbufCallbackBase<Connection> {
   std::deque<redis::CommandTokens> multi_cmds_;
 
   bool importing_ = false;
+  RESP protocol_version_ = RESP::v2;
 };
 
 }  // namespace redis
diff --git a/src/server/redis_reply.h b/src/server/redis_reply.h
index c3cc1b44..e23380cc 100644
--- a/src/server/redis_reply.h
+++ b/src/server/redis_reply.h
@@ -31,6 +31,8 @@
 
 namespace redis {
 
+enum class RESP { v2, v3 };
+
 void Reply(evbuffer *output, const std::string &data);
 std::string SimpleString(const std::string &data);
 std::string Error(const std::string &err);
@@ -40,6 +42,13 @@ std::string Integer(T data) {
   return ":" + std::to_string(data) + CRLF;
 }
 
+inline std::string Bool(const RESP ver, const bool b) {
+  if (ver == RESP::v3) {
+    return b ? "#t" CRLF : "#f" CRLF;
+  }
+  return Integer(b ? 1 : 0);
+}
+
 std::string BulkString(const std::string &data);
 std::string NilString();
 
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index 3f1bcdde..b8fbc02a 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -417,7 +417,7 @@ Status FunctionCall(redis::Connection *conn, const 
std::string &name, const std:
     lua_pop(lua, 2);
     return {Status::NotOK, fmt::format("Error while running function `{}`: 
{}", name, err_msg)};
   } else {
-    *output = ReplyToRedisReply(lua);
+    *output = ReplyToRedisReply(conn, lua);
     lua_pop(lua, 2);
   }
 
@@ -628,7 +628,7 @@ Status EvalGenericCommand(redis::Connection *conn, const 
std::string &body_or_sh
     *output = redis::Error(msg);
     lua_pop(lua, 2);
   } else {
-    *output = ReplyToRedisReply(lua);
+    *output = ReplyToRedisReply(conn, lua);
     lua_pop(lua, 2);
   }
 
@@ -1073,7 +1073,7 @@ void PushError(lua_State *lua, const char *err) {
 }
 
 // this function does not pop any element on the stack
-std::string ReplyToRedisReply(lua_State *lua) {
+std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
   std::string output;
   const char *obj_s = nullptr;
   size_t obj_len = 0;
@@ -1085,7 +1085,11 @@ std::string ReplyToRedisReply(lua_State *lua) {
       output = redis::BulkString(std::string(obj_s, obj_len));
       break;
     case LUA_TBOOLEAN:
-      output = lua_toboolean(lua, -1) ? redis::Integer(1) : redis::NilString();
+      if (conn->GetProtocolVersion() == redis::RESP::v2) {
+        output = lua_toboolean(lua, -1) ? redis::Integer(1) : 
redis::NilString();
+      } else {
+        output = redis::Bool(redis::RESP::v3, lua_toboolean(lua, -1));
+      }
       break;
     case LUA_TNUMBER:
       output = redis::Integer((int64_t)(lua_tonumber(lua, -1)));
@@ -1127,7 +1131,7 @@ std::string ReplyToRedisReply(lua_State *lua) {
             break;
           }
           mbulklen++;
-          output += ReplyToRedisReply(lua);
+          output += ReplyToRedisReply(conn, lua);
           lua_pop(lua, 1);
         }
         output = redis::MultiLen(mbulklen) + output;
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index 9e8c45f2..0d9ce46c 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -83,7 +83,7 @@ const char *RedisProtocolToLuaTypeNull(lua_State *lua, const 
char *reply);
 const char *RedisProtocolToLuaTypeBool(lua_State *lua, const char *reply, int 
tf);
 const char *RedisProtocolToLuaTypeDouble(lua_State *lua, const char *reply);
 
-std::string ReplyToRedisReply(lua_State *lua);
+std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua);
 
 void PushError(lua_State *lua, const char *err);
 [[noreturn]] int RaiseError(lua_State *lua);
diff --git a/tests/gocase/go.mod b/tests/gocase/go.mod
index 48f0cf97..298d8363 100644
--- a/tests/gocase/go.mod
+++ b/tests/gocase/go.mod
@@ -3,7 +3,7 @@ module github.com/apache/kvrocks/tests/gocase
 go 1.19
 
 require (
-       github.com/redis/go-redis/v9 v9.0.4
+       github.com/redis/go-redis/v9 v9.3.1
        github.com/shirou/gopsutil/v3 v3.22.9
        github.com/stretchr/testify v1.8.0
        golang.org/x/exp v0.0.0-20220929160808-de9c53c655b9
diff --git a/tests/gocase/go.sum b/tests/gocase/go.sum
index eda0deae..a95eb8ce 100644
--- a/tests/gocase/go.sum
+++ b/tests/gocase/go.sum
@@ -1,5 +1,5 @@
-github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao=
-github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y=
+github.com/bsm/ginkgo/v2 v2.12.0 
h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
+github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
 github.com/cespare/xxhash/v2 v2.2.0 
h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
 github.com/cespare/xxhash/v2 v2.2.0/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -20,8 +20,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
 github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod 
h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
 github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b 
h1:0LFwY6Q3gMACTjAbMZBjXAqTOzOwFaj2Ld6cjeQ7Rig=
 github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b/go.mod 
h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
-github.com/redis/go-redis/v9 v9.0.4 
h1:FC82T+CHJ/Q/PdyLW++GeCO+Ol59Y4T7R4jbgjvktgc=
-github.com/redis/go-redis/v9 v9.0.4/go.mod 
h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk=
+github.com/redis/go-redis/v9 v9.3.1 
h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
+github.com/redis/go-redis/v9 v9.3.1/go.mod 
h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
 github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod 
h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec 
h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod 
h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
diff --git a/tests/gocase/integration/slotmigrate/slotmigrate_test.go 
b/tests/gocase/integration/slotmigrate/slotmigrate_test.go
index b64a841f..afae86d1 100644
--- a/tests/gocase/integration/slotmigrate/slotmigrate_test.go
+++ b/tests/gocase/integration/slotmigrate/slotmigrate_test.go
@@ -562,7 +562,7 @@ func TestSlotMigrateDataType(t *testing.T) {
                require.NoError(t, rdb0.SRem(ctx, keys["set"], 1, 3).Err())
                require.NoError(t, rdb0.Expire(ctx, keys["set"], 
10*time.Second).Err())
                // type zset
-               require.NoError(t, rdb0.ZAdd(ctx, keys["zset"], []redis.Z{{0, 
1}, {2, 3}, {4, 5}}...).Err())
+               require.NoError(t, rdb0.ZAdd(ctx, keys["zset"], []redis.Z{{0, 
"1"}, {2, "3"}, {4, "5"}}...).Err())
                require.NoError(t, rdb0.ZRem(ctx, keys["zset"], 1, 3).Err())
                require.NoError(t, rdb0.Expire(ctx, keys["zset"], 
10*time.Second).Err())
                // type bitmap
diff --git a/tests/gocase/unit/debug/debug_test.go 
b/tests/gocase/unit/debug/debug_test.go
new file mode 100644
index 00000000..7221d830
--- /dev/null
+++ b/tests/gocase/unit/debug/debug_test.go
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package debug
+
+import (
+       "context"
+       "testing"
+
+       "github.com/redis/go-redis/v9"
+
+       "github.com/apache/kvrocks/tests/gocase/util"
+       "github.com/stretchr/testify/require"
+)
+
+func TestDebugProtocolV2(t *testing.T) {
+       srv := util.StartServer(t, map[string]string{
+               "resp3-enabled": "no",
+       })
+       defer srv.Close()
+
+       ctx := context.Background()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       t.Run("debug protocol type", func(t *testing.T) {
+               types := map[string]interface{}{
+                       "string":  "Hello World",
+                       "integer": int64(12345),
+                       "array":   []interface{}{int64(0), int64(1), int64(2)},
+                       "true":    int64(1),
+                       "false":   int64(0),
+               }
+               for typ, expectedValue := range types {
+                       r := rdb.Do(ctx, "DEBUG", "PROTOCOL", typ)
+                       require.NoError(t, r.Err())
+                       require.EqualValues(t, expectedValue, r.Val())
+               }
+       })
+
+       t.Run("lua script return value type", func(t *testing.T) {
+               var returnValueScript = redis.NewScript(`
+                       return ARGV[1]
+               `)
+               val, err := returnValueScript.Run(ctx, rdb, []string{}, 
true).Int()
+               require.NoError(t, err)
+               require.EqualValues(t, 1, val)
+               val, err = returnValueScript.Run(ctx, rdb, []string{}, 
false).Int()
+               require.NoError(t, err)
+               require.EqualValues(t, 0, val)
+       })
+}
+
+func TestDebugProtocolV3(t *testing.T) {
+       srv := util.StartServer(t, map[string]string{
+               "resp3-enabled": "yes",
+       })
+       defer srv.Close()
+
+       ctx := context.Background()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       t.Run("debug protocol type", func(t *testing.T) {
+               types := map[string]interface{}{
+                       "string":  "Hello World",
+                       "integer": int64(12345),
+                       "array":   []interface{}{int64(0), int64(1), int64(2)},
+                       "true":    true,
+                       "false":   false,
+               }
+               for typ, expectedValue := range types {
+                       r := rdb.Do(ctx, "DEBUG", "PROTOCOL", typ)
+                       require.NoError(t, r.Err())
+                       require.EqualValues(t, expectedValue, r.Val())
+               }
+       })
+
+       t.Run("lua script return value type", func(t *testing.T) {
+               var returnValueScript = redis.NewScript(`
+                       return ARGV[1]
+               `)
+               val, err := returnValueScript.Run(ctx, rdb, []string{}, 
true).Bool()
+               require.NoError(t, err)
+               require.EqualValues(t, true, val)
+               val, err = returnValueScript.Run(ctx, rdb, []string{}, 
false).Bool()
+               require.NoError(t, err)
+               require.EqualValues(t, false, val)
+       })
+}
diff --git a/tests/gocase/unit/hello/hello_test.go 
b/tests/gocase/unit/hello/hello_test.go
index 984a11c6..d965b29c 100644
--- a/tests/gocase/unit/hello/hello_test.go
+++ b/tests/gocase/unit/hello/hello_test.go
@@ -76,6 +76,27 @@ func TestHello(t *testing.T) {
        })
 }
 
+func TestEnableRESP3(t *testing.T) {
+       srv := util.StartServer(t, map[string]string{
+               "resp3-enabled": "yes",
+       })
+       defer srv.Close()
+
+       ctx := context.Background()
+       rdb := srv.NewClient()
+       defer func() { require.NoError(t, rdb.Close()) }()
+
+       r := rdb.Do(ctx, "HELLO", "2")
+       rList := r.Val().([]interface{})
+       require.EqualValues(t, rList[2], "proto")
+       require.EqualValues(t, rList[3], 2)
+
+       r = rdb.Do(ctx, "HELLO", "3")
+       rList = r.Val().([]interface{})
+       require.EqualValues(t, rList[2], "proto")
+       require.EqualValues(t, rList[3], 3)
+}
+
 func TestHelloWithAuth(t *testing.T) {
        srv := util.StartServer(t, map[string]string{
                "requirepass": "foobar",
diff --git a/tests/gocase/unit/type/zset/zset_test.go 
b/tests/gocase/unit/type/zset/zset_test.go
index d5822221..d1590425 100644
--- a/tests/gocase/unit/type/zset/zset_test.go
+++ b/tests/gocase/unit/type/zset/zset_test.go
@@ -1340,7 +1340,7 @@ func stressTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding
                                } else if auxList[i].Score > auxList[j].Score {
                                        return false
                                } else {
-                                       if 
strings.Compare(auxList[i].Member.(string), auxList[j].Member.(string)) == 1 {
+                                       if strings.Compare(auxList[i].Member, 
auxList[j].Member) == 1 {
                                                return false
                                        } else {
                                                return true
@@ -1349,7 +1349,7 @@ func stressTests(t *testing.T, rdb *redis.Client, ctx 
context.Context, encoding
                        })
                        var aux []string
                        for _, z := range auxList {
-                               aux = append(aux, z.Member.(string))
+                               aux = append(aux, z.Member)
                        }
                        fromRedis := rdb.ZRange(ctx, "myzset", 0, -1).Val()
                        for i := 0; i < len(fromRedis); i++ {

Reply via email to