This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new db2d244f Add support of RESP3 in Lua (#2119)
db2d244f is described below
commit db2d244ffd8ef9c5fdb8e4c3c346cef4a195da02
Author: hulk <[email protected]>
AuthorDate: Thu Feb 29 19:00:38 2024 +0800
Add support of RESP3 in Lua (#2119)
---
src/storage/scripting.cc | 160 ++++++++++++++++++++++++--
src/storage/scripting.h | 2 +
tests/gocase/unit/scripting/scripting_test.go | 64 +++++++++++
3 files changed, 217 insertions(+), 9 deletions(-)
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index e4c6d1fb..cd024940 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -955,6 +955,12 @@ const char *RedisProtocolToLuaType(lua_State *lua, const char *reply) {
case ',':
p = RedisProtocolToLuaTypeDouble(lua, reply);
break;
+ case '(':
+ p = RedisProtocolToLuaTypeBigNumber(lua, reply);
+ break;
+ case '=':
+ p = RedisProtocolToLuaTypeVerbatimString(lua, reply);
+ break;
}
return p;
}
@@ -1009,13 +1015,36 @@ const char *RedisProtocolToLuaTypeAggregate(lua_State *lua, const char *reply, i
lua_pushboolean(lua, 0);
return p;
}
- lua_newtable(lua);
- for (j = 0; j < mbulklen; j++) {
- lua_pushnumber(lua, j + 1);
- p = RedisProtocolToLuaType(lua, p);
+ if (atype == '*') {
+ lua_newtable(lua);
+ for (j = 0; j < mbulklen; j++) {
+ lua_pushnumber(lua, j + 1);
+ p = RedisProtocolToLuaType(lua, p);
+ lua_settable(lua, -3);
+ }
+ return p;
+ }
+
+ CHECK(atype == '%' || atype == '~');
+ if (atype == '%' || atype == '~') {
+ lua_newtable(lua);
+ lua_pushstring(lua, atype == '%' ? "map" : "set");
+ lua_newtable(lua);
+ for (j = 0; j < mbulklen; j++) {
+ p = RedisProtocolToLuaType(lua, p);
+ if (atype == '%') { // map
+ p = RedisProtocolToLuaType(lua, p);
+ } else { // set
+ lua_pushboolean(lua, 1);
+ }
+ lua_settable(lua, -3);
+ }
lua_settable(lua, -3);
+ return p;
}
- return p;
+
+ // Unreachable, return the original position if it did reach here.
+ return reply;
}
const char *RedisProtocolToLuaTypeNull(lua_State *lua, const char *reply) {
@@ -1051,6 +1080,36 @@ const char *RedisProtocolToLuaTypeDouble(lua_State *lua, const char *reply) {
return p + 2;
}
+const char *RedisProtocolToLuaTypeBigNumber(lua_State *lua, const char *reply) {
+ const char *p = strchr(reply + 1, '\r');
+ lua_newtable(lua);
+ lua_pushstring(lua, "big_number");
+ lua_pushlstring(lua, reply + 1, p - reply - 1);
+ lua_settable(lua, -3);
+ return p + 2;
+}
+
+const char *RedisProtocolToLuaTypeVerbatimString(lua_State *lua, const char *reply) {
+ const char *p = strchr(reply + 1, '\r');
+ int64_t bulklen = ParseInt<int64_t>(std::string(reply + 1, p - reply - 1), 10).ValueOr(0);
+ p += 2; // skip \r\n
+
+ lua_newtable(lua);
+ lua_pushstring(lua, "verbatim_string");
+
+ lua_newtable(lua);
+ lua_pushstring(lua, "string");
+ lua_pushlstring(lua, p + 4, bulklen - 4);
+ lua_settable(lua, -3);
+
+ lua_pushstring(lua, "format");
+ lua_pushlstring(lua, p, 3);
+ lua_settable(lua, -3);
+
+ lua_settable(lua, -3);
+ return p + bulklen + 2;
+}
+
/* This function is used in order to push an error on the Lua stack in the
* format used by redis.pcall to return errors, which is a lua table
* with a single "err" field set to the error string. Note that this
@@ -1094,7 +1153,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
/* Handle error reply. */
lua_pushstring(lua, "err");
- lua_gettable(lua, -2);
+ lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
output = redis::Error(lua_tostring(lua, -1));
@@ -1105,7 +1164,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
/* Handle status reply. */
lua_pushstring(lua, "ok");
- lua_gettable(lua, -2);
+ lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
obj_s = lua_tolstring(lua, -1, &obj_len);
@@ -1115,9 +1174,20 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
}
lua_pop(lua, 1); /* Discard the 'ok' field value we pushed */
+ /* Handle double reply. */
+ lua_pushstring(lua, "double");
+ lua_rawget(lua, -2);
+ t = lua_type(lua, -1);
+ if (t == LUA_TNUMBER) {
+ output = conn->Double(lua_tonumber(lua, -1));
+ lua_pop(lua, 1);
+ return output;
+ }
+ lua_pop(lua, 1); /* Discard the 'double' field value we pushed */
+
/* Handle big number reply. */
lua_pushstring(lua, "big_number");
- lua_gettable(lua, -2);
+ lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TSTRING) {
obj_s = lua_tolstring(lua, -1, &obj_len);
@@ -1127,10 +1197,82 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) {
}
lua_pop(lua, 1); /* Discard the 'big_number' field value we pushed */
+ /* Handle verbatim reply. */
+ lua_pushstring(lua, "verbatim_string");
+ lua_rawget(lua, -2);
+ t = lua_type(lua, -1);
+ if (t == LUA_TTABLE) {
+ lua_pushstring(lua, "format");
+ lua_rawget(lua, -2);
+ t = lua_type(lua, -1);
+ if (t == LUA_TSTRING) {
+ const char *format = lua_tostring(lua, -1);
+ lua_pushstring(lua, "string");
+ lua_rawget(lua, -3);
+ t = lua_type(lua, -1);
+ if (t == LUA_TSTRING) {
+ obj_s = lua_tolstring(lua, -1, &obj_len);
+ output = conn->VerbatimString(std::string(format), std::string(obj_s, obj_len));
+ lua_pop(lua, 4);
+ return output;
+ }
+ // discard 'string'
+ lua_pop(lua, 1);
+ }
+ // discard 'format'
+ lua_pop(lua, 1);
+ }
+ lua_pop(lua, 1); /* Discard the 'verbatim_string' field value we pushed */
+
+ /* Handle map reply. */
+ lua_pushstring(lua, "map");
+ lua_rawget(lua, -2);
+ t = lua_type(lua, -1);
+ if (t == LUA_TTABLE) {
+ int map_len = 0;
+ std::string map_output;
+ lua_pushnil(lua);
+ while (lua_next(lua, -2)) {
+ lua_pushvalue(lua, -2);
+ // return key
+ map_output += ReplyToRedisReply(conn, lua);
+ lua_pop(lua, 1);
+ // return value
+ map_output += ReplyToRedisReply(conn, lua);
+ lua_pop(lua, 1);
+ map_len++;
+ }
+ output = conn->HeaderOfMap(map_len) + std::move(map_output);
+ lua_pop(lua, 1);
+ return output;
+ }
+ lua_pop(lua, 1); /* Discard the 'map' field value we pushed */
+
+ /* Handle set reply. */
+ lua_pushstring(lua, "set");
+ lua_rawget(lua, -2);
+ t = lua_type(lua, -1);
+ if (t == LUA_TTABLE) {
+ int set_len = 0;
+ std::string set_output;
+ lua_pushnil(lua);
+ while (lua_next(lua, -2)) {
+ lua_pop(lua, 1);
+ lua_pushvalue(lua, -1);
+ set_output += ReplyToRedisReply(conn, lua);
+ lua_pop(lua, 1);
+ set_len++;
+ }
+ output = conn->HeaderOfSet(set_len) + std::move(set_output);
+ lua_pop(lua, 1);
+ return output;
+ }
+ lua_pop(lua, 1); /* Discard the 'set' field value we pushed */
+
j = 1, mbulklen = 0;
while (true) {
lua_pushnumber(lua, j++);
- lua_gettable(lua, -2);
+ lua_rawget(lua, -2);
t = lua_type(lua, -1);
if (t == LUA_TNIL) {
lua_pop(lua, 1);
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index a2c90b90..f68db2f8 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -83,6 +83,8 @@ const char *RedisProtocolToLuaTypeAggregate(lua_State *lua, const char *reply, i
const char *RedisProtocolToLuaTypeNull(lua_State *lua, const char *reply);
const char *RedisProtocolToLuaTypeBool(lua_State *lua, const char *reply, int tf);
const char *RedisProtocolToLuaTypeDouble(lua_State *lua, const char *reply);
+const char *RedisProtocolToLuaTypeBigNumber(lua_State *lua, const char *reply);
+const char *RedisProtocolToLuaTypeVerbatimString(lua_State *lua, const char *reply);
std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua);
diff --git a/tests/gocase/unit/scripting/scripting_test.go b/tests/gocase/unit/scripting/scripting_test.go
index 6f8180bb..cf4a6e34 100644
--- a/tests/gocase/unit/scripting/scripting_test.go
+++ b/tests/gocase/unit/scripting/scripting_test.go
@@ -22,11 +22,13 @@ package scripting
import (
"context"
"fmt"
+ "math/big"
"testing"
"github.com/apache/kvrocks/tests/gocase/util"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
+ "golang.org/x/exp/slices"
)
func TestScripting(t *testing.T) {
@@ -512,3 +514,65 @@ func TestScriptingMasterSlave(t *testing.T) {
require.Equal(t, []bool{false}, slaveClient.ScriptExists(ctx, sha).Val())
})
}
+
+func TestScriptingWithRESP3(t *testing.T) {
+ srv := util.StartServer(t, map[string]string{
+ "resp3-enabled": "yes",
+ })
+ defer srv.Close()
+
+ rdb := srv.NewClient()
+ defer func() {
+ require.NoError(t, rdb.Close())
+ }()
+
+ ctx := context.Background()
+ t.Run("EVAL - Redis protocol type map conversion", func(t *testing.T) {
+ rdb.HSet(ctx, "myhash", "f1", "v1")
+ rdb.HSet(ctx, "myhash", "f2", "v2")
+ val, err := rdb.Eval(ctx, `return redis.call('hgetall', KEYS[1])`, []string{"myhash"}).Result()
+ require.NoError(t, err)
+ require.Equal(t, map[interface{}]interface{}{"f1": "v1", "f2": "v2"}, val)
+ })
+
+ t.Run("EVAL - Redis protocol type set conversion", func(t *testing.T) {
+ require.NoError(t, rdb.SAdd(ctx, "myset", "m0", "m1", "m2").Err())
+ val, err := rdb.Eval(ctx, `return redis.call('smembers', KEYS[1])`, []string{"myset"}).StringSlice()
+ require.NoError(t, err)
+ slices.Sort(val)
+ require.EqualValues(t, []string{"m0", "m1", "m2"}, val)
+ })
+
+ t.Run("EVAL - Redis protocol type double conversion", func(t *testing.T) {
+ require.NoError(t, rdb.ZAdd(ctx, "mydouble", redis.Z{Member: "z0", Score: 1.5}).Err())
+ val, err := rdb.Eval(ctx, `return redis.call('zscore', KEYS[1], KEYS[2])`, []string{"mydouble", "z0"}).Result()
+ require.NoError(t, err)
+ require.EqualValues(t, 1.5, val)
+ })
+
+ t.Run("EVAL - Redis protocol type bignumber conversion", func(t *testing.T) {
+ val, err := rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'bignum')`, []string{}).Result()
+ require.NoError(t, err)
+
+ bignum, _ := big.NewInt(0).SetString("1234567999999999999999999999999999999", 10)
+ require.EqualValues(t, bignum, val)
+ })
+
+ t.Run("EVAL - Redis protocol type boolean conversion", func(t *testing.T) {
+ val, err := rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'true')`, []string{}).Result()
+ require.NoError(t, err)
+ require.EqualValues(t, true, val)
+
+ val, err = rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'false')`, []string{}).Result()
+ require.NoError(t, err)
+ require.EqualValues(t, false, val)
+ })
+
+ t.Run("EVAL - Redis protocol type verbatim conversion", func(t *testing.T) {
+ val, err := rdb.Eval(ctx, `return redis.call('debug', 'protocol', 'verbatim')`, []string{}).Result()
+ require.NoError(t, err)
+
+ require.EqualValues(t, "verbatim string", val)
+ })
+
+}