This is an automated email from the ASF dual-hosted git repository.

caipengbo pushed a commit to branch 2.8
in repository https://gitbox.apache.org/repos/asf/kvrocks.git

commit ca5bac80e3322b61aec937784534a6bb380535f0
Author: hulk <[email protected]>
AuthorDate: Fri Mar 1 20:06:15 2024 +0800

    Add support of lua function 'redis.setresp()' (#2130)
    
    Besides adding the `redis.setresp` function, this PR also fixes a behavior
    difference from Redis in Lua scripts.
    
    Before this PR, a Lua script returned its result in RESP3 format whenever
    the connection had been switched with the `HELLO 3` command. Redis, however,
    always uses RESP2 inside scripts unless the script explicitly calls
    `redis.setresp(3)`.
    
    ```
    // Kvrocks
    
    ❯ redis-cli -3 -p 6666
    
    127.0.0.1:6666> EVAL 'return redis.call("hgetall","hash")' 0
    1# "f1" => "v1"
    
    // Redis
    
    ❯ redis-cli -3
    127.0.0.1:6379> EVAL 'return redis.call("hgetall","hash")' 0
    1) "f1"
    2) "v1"
    ```
    
    After applying this PR, it behaves consistently with Redis:
    
    ```
    ❯ redis-cli -3 -p 6666
    127.0.0.1:6666> EVAL 'redis.setresp(3);return redis.call("hgetall","hash")' 0
    1# "f1" => "v1"
    127.0.0.1:6666> EVAL 'return redis.call("hgetall","hash")' 0
    1) "f1"
    2) "v1"
    ```
    
    
    Co-authored-by: Binbin <[email protected]>
---
 src/storage/scripting.cc                      | 34 +++++++++++++++++++++
 src/storage/scripting.h                       |  1 +
 tests/gocase/unit/scripting/scripting_test.go | 43 +++++++++++++++++++++++----
 3 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index cd024940..65d8179d 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -87,6 +87,11 @@ void LoadFuncs(lua_State *lua, bool read_only) {
   lua_pushcfunction(lua, RedisPCallCommand);
   lua_settable(lua, -3);
 
+  /* redis.setresp */
+  lua_pushstring(lua, "setresp");
+  lua_pushcfunction(lua, RedisSetResp);
+  lua_settable(lua, -3);
+
   /* redis.log and log levels. */
   lua_pushstring(lua, "log");
   lua_pushcfunction(lua, RedisLogCommand);
@@ -621,6 +626,12 @@ Status EvalGenericCommand(redis::Connection *conn, const 
std::string &body_or_sh
     lua_getglobal(lua, funcname);
   }
 
+  // For the Lua script, should be always run with RESP2 protocol,
+  // unless users explicitly set the protocol version in the script via 
`redis.setresp`.
+  // So we need to save the current protocol version and set it to RESP2,
+  // and then restore it after the script execution.
+  auto saved_protocol_version = conn->GetProtocolVersion();
+  conn->SetProtocolVersion(redis::RESP::v2);
   /* Populate the argv and keys table accordingly to the arguments that
    * EVAL received. */
   SetGlobalArray(lua, "KEYS", keys);
@@ -634,6 +645,7 @@ Status EvalGenericCommand(redis::Connection *conn, const 
std::string &body_or_sh
     *output = ReplyToRedisReply(conn, lua);
     lua_pop(lua, 2);
   }
+  conn->SetProtocolVersion(saved_protocol_version);
 
   // clean global variables to prevent information leak in function commands
   lua_pushnil(lua);
@@ -848,6 +860,28 @@ int RedisReturnSingleFieldTable(lua_State *lua, const char 
*field) {
   return 1;
 }
 
+/* redis.setresp(version): set the current connection's RESP version; only 2 and 3 are accepted. */
+int RedisSetResp(lua_State *lua) {
+  if (lua_gettop(lua) != 1) {
+    PushError(lua, "redis.setresp() requires one argument.");
+    return RaiseError(lua);
+  }
+
+  auto resp = static_cast<int>(lua_tonumber(lua, -1));
+  if (resp != 2 && resp != 3) {
+    PushError(lua, "RESP version must be 2 or 3.");
+    return RaiseError(lua);
+  }
+  // Reject RESP3 before touching the connection, so a failed call leaves the protocol unchanged.
+  auto srv = GetServer(lua);
+  if (resp == 3 && !srv->GetConfig()->resp3_enabled) {
+    PushError(lua, "You need set resp3-enabled to yes to enable RESP3.");
+    return RaiseError(lua);
+  }
+  srv->GetCurrentConnection()->SetProtocolVersion(resp == 2 ? redis::RESP::v2 : redis::RESP::v3);
+  return 0;
+}
+
 /* redis.error_reply() */
 int RedisErrorReplyCommand(lua_State *lua) { return 
RedisReturnSingleFieldTable(lua, "err"); }
 
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index f68db2f8..3b2dd45d 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -53,6 +53,7 @@ int RedisStatusReplyCommand(lua_State *lua);
 int RedisErrorReplyCommand(lua_State *lua);
 int RedisLogCommand(lua_State *lua);
 int RedisRegisterFunction(lua_State *lua);
+int RedisSetResp(lua_State *lua);
 
 Status CreateFunction(Server *srv, const std::string &body, std::string *sha, 
lua_State *lua, bool need_to_store);
 
diff --git a/tests/gocase/unit/scripting/scripting_test.go 
b/tests/gocase/unit/scripting/scripting_test.go
index cf4a6e34..50680f7d 100644
--- a/tests/gocase/unit/scripting/scripting_test.go
+++ b/tests/gocase/unit/scripting/scripting_test.go
@@ -483,6 +483,11 @@ math.randomseed(ARGV[1]); return tostring(math.random())
                r := rdb.Do(ctx, "EVALSHA_RO", 
"a1e63e1cd1bd1d5413851949332cfb9da4ee6dc0", "1", "foo")
                util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not 
allowed from read-only scripts")
        })
+
+       t.Run("EVAL - cannot use redis.setresp(3) if RESP3 is disabled", func(t 
*testing.T) {
+               r := rdb.Eval(ctx, `redis.setresp(3);`, []string{})
+               util.ErrorRegexp(t, r.Err(), ".*ERR.*You need set resp3-enabled 
to yes to enable RESP3.*")
+       })
 }
 
 func TestScriptingMasterSlave(t *testing.T) {
@@ -530,14 +535,14 @@ func TestScriptingWithRESP3(t *testing.T) {
        t.Run("EVAL - Redis protocol type map conversion", func(t *testing.T) {
                rdb.HSet(ctx, "myhash", "f1", "v1")
                rdb.HSet(ctx, "myhash", "f2", "v2")
-               val, err := rdb.Eval(ctx, `return redis.call('hgetall', 
KEYS[1])`, []string{"myhash"}).Result()
+               val, err := rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('hgetall', KEYS[1])`, []string{"myhash"}).Result()
                require.NoError(t, err)
                require.Equal(t, map[interface{}]interface{}{"f1": "v1", "f2": 
"v2"}, val)
        })
 
        t.Run("EVAL - Redis protocol type set conversion", func(t *testing.T) {
                require.NoError(t, rdb.SAdd(ctx, "myset", "m0", "m1", 
"m2").Err())
-               val, err := rdb.Eval(ctx, `return redis.call('smembers', 
KEYS[1])`, []string{"myset"}).StringSlice()
+               val, err := rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('smembers', KEYS[1])`, []string{"myset"}).StringSlice()
                require.NoError(t, err)
                slices.Sort(val)
                require.EqualValues(t, []string{"m0", "m1", "m2"}, val)
@@ -545,13 +550,13 @@ func TestScriptingWithRESP3(t *testing.T) {
 
        t.Run("EVAL - Redis protocol type double conversion", func(t 
*testing.T) {
                require.NoError(t, rdb.ZAdd(ctx, "mydouble", redis.Z{Member: 
"z0", Score: 1.5}).Err())
-               val, err := rdb.Eval(ctx, `return redis.call('zscore', KEYS[1], 
KEYS[2])`, []string{"mydouble", "z0"}).Result()
+               val, err := rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('zscore', KEYS[1], KEYS[2])`, []string{"mydouble", "z0"}).Result()
                require.NoError(t, err)
                require.EqualValues(t, 1.5, val)
        })
 
        t.Run("EVAL - Redis protocol type bignumber conversion", func(t 
*testing.T) {
-               val, err := rdb.Eval(ctx, `return redis.call('debug', 
'protocol', 'bignum')`, []string{}).Result()
+               val, err := rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('debug', 'protocol', 'bignum')`, []string{}).Result()
                require.NoError(t, err)
 
                bignum, _ := 
big.NewInt(0).SetString("1234567999999999999999999999999999999", 10)
@@ -559,11 +564,11 @@ func TestScriptingWithRESP3(t *testing.T) {
        })
 
        t.Run("EVAL - Redis protocol type boolean conversion", func(t 
*testing.T) {
-               val, err := rdb.Eval(ctx, `return redis.call('debug', 
'protocol', 'true')`, []string{}).Result()
+               val, err := rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('debug', 'protocol', 'true')`, []string{}).Result()
                require.NoError(t, err)
                require.EqualValues(t, true, val)
 
-               val, err = rdb.Eval(ctx, `return redis.call('debug', 
'protocol', 'false')`, []string{}).Result()
+               val, err = rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('debug', 'protocol', 'false')`, []string{}).Result()
                require.NoError(t, err)
                require.EqualValues(t, false, val)
        })
@@ -575,4 +580,30 @@ func TestScriptingWithRESP3(t *testing.T) {
                require.EqualValues(t, "verbatim string", val)
        })
 
+       t.Run("EVAL - lua redis.setresp function", func(t *testing.T) {
+               err := rdb.Eval(ctx, `return redis.setresp(2, 3);`, 
[]string{}).Err()
+               util.ErrorRegexp(t, err, ".*ERR.*requires one argument.*")
+
+               err = rdb.Eval(ctx, `return redis.setresp(4);`, 
[]string{}).Err()
+               util.ErrorRegexp(t, err, ".*ERR.*RESP version must be 2 or 3.*")
+
+               // set to RESP3
+               err = rdb.Eval(ctx, `return redis.setresp(3);`, 
[]string{}).Err()
+               require.ErrorIs(t, err, redis.Nil)
+
+               rdb.HSet(ctx, "hash0", "f1", "v1")
+               vals, err := rdb.Eval(ctx, `redis.setresp(3); return 
redis.call('hgetall', KEYS[1])`, []string{"hash0"}).Result()
+               require.NoError(t, err)
+               // return as a map in RESP3
+               require.EqualValues(t, map[interface{}]interface{}{"f1": "v1"}, 
vals)
+
+               // set to RESP2
+               err = rdb.Eval(ctx, `return redis.setresp(2);`, 
[]string{}).Err()
+               require.ErrorIs(t, err, redis.Nil)
+
+               vals, err = rdb.Eval(ctx, `return redis.call('hgetall', 
KEYS[1])`, []string{"hash0"}).Result()
+               require.NoError(t, err)
+               // return as an array in RESP2
+               require.EqualValues(t, []interface{}{"f1", "v1"}, vals)
+       })
 }

Reply via email to