This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new b4b14f0b4 feat(script): pass storage context through scripting (#2901)
b4b14f0b4 is described below

commit b4b14f0b4bd9e4033950541daaa787dc5293e893
Author: Twice <[email protected]>
AuthorDate: Sat Apr 26 20:11:01 2025 +0800

    feat(script): pass storage context through scripting (#2901)
    
    Signed-off-by: PragmaTwice <[email protected]>
---
 src/commands/cmd_function.cc   | 11 +++++----
 src/commands/cmd_script.cc     |  2 +-
 src/server/redis_connection.cc |  2 +-
 src/storage/scripting.cc       | 52 +++++++++++++++++++++---------------------
 src/storage/scripting.h        | 29 +++++++++++++----------
 5 files changed, 51 insertions(+), 45 deletions(-)

diff --git a/src/commands/cmd_function.cc b/src/commands/cmd_function.cc
index 82773e5f7..1d5ef8461 100644
--- a/src/commands/cmd_function.cc
+++ b/src/commands/cmd_function.cc
@@ -39,7 +39,7 @@ struct CommandFunction : Commander {
       }
 
       std::string libname;
-      auto s = lua::FunctionLoad(conn, GET_OR_RET(parser.TakeStr()), true, 
replace, &libname);
+      auto s = lua::FunctionLoad(conn, &ctx, GET_OR_RET(parser.TakeStr()), 
true, replace, &libname);
       if (!s) return s;
 
       *output = SimpleString(libname);
@@ -55,21 +55,21 @@ struct CommandFunction : Commander {
         with_code = true;
       }
 
-      return lua::FunctionList(srv, conn, libname, with_code, output);
+      return lua::FunctionList(srv, conn, ctx, libname, with_code, output);
     } else if (parser.EatEqICase("listfunc")) {
       std::string funcname;
       if (parser.EatEqICase("funcname")) {
         funcname = GET_OR_RET(parser.TakeStr());
       }
 
-      return lua::FunctionListFunc(srv, conn, funcname, output);
+      return lua::FunctionListFunc(srv, conn, ctx, funcname, output);
     } else if (parser.EatEqICase("listlib")) {
       auto libname = GET_OR_RET(parser.TakeStr().Prefixed("expect a library 
name"));
 
       return lua::FunctionListLib(conn, libname, output);
     } else if (parser.EatEqICase("delete")) {
       auto libname = GET_OR_RET(parser.TakeStr());
-      if (!lua::FunctionIsLibExist(conn, libname)) {
+      if (!lua::FunctionIsLibExist(conn, &ctx, libname)) {
         return {Status::NotOK, "no such library"};
       }
       auto s = lua::FunctionDelete(ctx, conn, libname);
@@ -94,7 +94,8 @@ struct CommandFCall : Commander {
       return {Status::NotOK, "Number of keys can't be negative"};
     }
 
-    return lua::FunctionCall(conn, args_[1], 
std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
+    return lua::FunctionCall(conn, &ctx, args_[1],
+                             std::vector<std::string>(args_.begin() + 3, 
args_.begin() + 3 + numkeys),
                              std::vector<std::string>(args_.begin() + 3 + 
numkeys, args_.end()), output, read_only);
   }
 };
diff --git a/src/commands/cmd_script.cc b/src/commands/cmd_script.cc
index 8dac03981..142076de0 100644
--- a/src/commands/cmd_script.cc
+++ b/src/commands/cmd_script.cc
@@ -43,7 +43,7 @@ class CommandEvalImpl : public Commander {
     }
 
     return lua::EvalGenericCommand(
-        conn, args_[1], std::vector<std::string>(args_.begin() + 3, 
args_.begin() + 3 + numkeys),
+        conn, &ctx, args_[1], std::vector<std::string>(args_.begin() + 3, 
args_.begin() + 3 + numkeys),
         std::vector<std::string>(args_.begin() + 3 + numkeys, args_.end()), 
evalsha, output, read_only);
   }
 };
diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc
index 765a0e517..aa21a5595 100644
--- a/src/server/redis_connection.cc
+++ b/src/server/redis_connection.cc
@@ -373,7 +373,7 @@ static bool IsCmdAllowedInStaleData(const std::string 
&cmd_name) {
 void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
   const Config *config = srv_->GetConfig();
   std::string reply;
-  std::string password = config->requirepass;
+  const std::string &password = config->requirepass;
 
   while (!to_process_cmds->empty()) {
     CommandTokens cmd_tokens = std::move(to_process_cmds->front());
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index 26bd34548..ebd194a23 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -281,8 +281,8 @@ int RedisRegisterFunction(lua_State *lua) {
   return 0;
 }
 
-Status FunctionLoad(redis::Connection *conn, const std::string &script, bool 
need_to_store, bool replace,
-                    [[maybe_unused]] std::string *lib_name, bool read_only) {
+Status FunctionLoad(redis::Connection *conn, engine::Context *ctx, const 
std::string &script, bool need_to_store,
+                    bool replace, [[maybe_unused]] std::string *lib_name, bool 
read_only) {
   std::string first_line, lua_code;
   if (auto pos = script.find('\n'); pos != std::string::npos) {
     first_line = script.substr(0, pos);
@@ -296,17 +296,17 @@ Status FunctionLoad(redis::Connection *conn, const 
std::string &script, bool nee
   auto srv = conn->GetServer();
   auto lua = conn->Owner()->Lua();
 
-  if (FunctionIsLibExist(conn, libname, need_to_store, read_only)) {
+  if (FunctionIsLibExist(conn, ctx, libname, need_to_store, read_only)) {
     if (!replace) {
       return {Status::NotOK, "library already exists, please specify REPLACE 
to force load"};
     }
-    engine::Context ctx(srv->storage);
-    auto s = FunctionDelete(ctx, conn, libname);
+    auto s = FunctionDelete(*ctx, conn, libname);
     if (!s) return s;
   }
 
   ScriptRunCtx script_run_ctx;
   script_run_ctx.conn = conn;
+  script_run_ctx.ctx = ctx;
   script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
 
   SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
@@ -339,14 +339,15 @@ Status FunctionLoad(redis::Connection *conn, const 
std::string &script, bool nee
 
   RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
 
-  if (!FunctionIsLibExist(conn, libname, false, read_only)) {
+  if (!FunctionIsLibExist(conn, ctx, libname, false, read_only)) {
     return {Status::NotOK, "Please register some function in FUNCTION LOAD"};
   }
 
   return need_to_store ? srv->FunctionSetCode(libname, script) : Status::OK();
 }
 
-bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, 
bool need_check_storage, bool read_only) {
+bool FunctionIsLibExist(redis::Connection *conn, engine::Context *ctx, const 
std::string &libname,
+                        bool need_check_storage, bool read_only) {
   auto srv = conn->GetServer();
   auto lua = conn->Owner()->Lua();
 
@@ -373,14 +374,15 @@ bool FunctionIsLibExist(redis::Connection *conn, const 
std::string &libname, boo
   if (!s) return false;
 
   std::string lib_name;
-  s = FunctionLoad(conn, code, false, false, &lib_name, read_only);
+  s = FunctionLoad(conn, ctx, code, false, false, &lib_name, read_only);
   return static_cast<bool>(s);
 }
 
 // FunctionCall will firstly find the function in the lua runtime,
 // if it is not found, it will try to load the library where the function is 
located from storage
-Status FunctionCall(redis::Connection *conn, const std::string &name, const 
std::vector<std::string> &keys,
-                    const std::vector<std::string> &argv, std::string *output, 
bool read_only) {
+Status FunctionCall(redis::Connection *conn, engine::Context *ctx, const 
std::string &name,
+                    const std::vector<std::string> &keys, const 
std::vector<std::string> &argv, std::string *output,
+                    bool read_only) {
   auto srv = conn->GetServer();
   auto lua = conn->Owner()->Lua();
 
@@ -397,7 +399,7 @@ Status FunctionCall(redis::Connection *conn, const 
std::string &name, const std:
     std::string libcode;
     s = srv->FunctionGetCode(libname, &libcode);
     if (!s) return s;
-    s = FunctionLoad(conn, libcode, false, false, &libname, read_only);
+    s = FunctionLoad(conn, ctx, libcode, false, false, &libname, read_only);
     if (!s) return s;
 
     lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());
@@ -405,6 +407,7 @@ Status FunctionCall(redis::Connection *conn, const 
std::string &name, const std:
 
   ScriptRunCtx script_run_ctx;
   script_run_ctx.conn = conn;
+  script_run_ctx.ctx = ctx;
   script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
   lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
   if (!lua_isnil(lua, -1)) {
@@ -447,13 +450,12 @@ Status FunctionCall(redis::Connection *conn, const 
std::string &name, const std:
 }
 
 // list all library names and their code (enabled via `with_code`)
-Status FunctionList(Server *srv, const redis::Connection *conn, const 
std::string &libname, bool with_code,
-                    std::string *output) {
+Status FunctionList(Server *srv, const redis::Connection *conn, 
engine::Context &ctx, const std::string &libname,
+                    bool with_code, std::string *output) {
   std::string start_key = engine::kLuaLibCodePrefix + libname;
   std::string end_key = start_key;
   end_key.back()++;
 
-  engine::Context ctx(srv->storage);
   rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
   rocksdb::Slice upper_bound(end_key);
   read_options.iterate_upper_bound = &upper_bound;
@@ -487,12 +489,12 @@ Status FunctionList(Server *srv, const redis::Connection 
*conn, const std::strin
 
 // extension to Redis Function
 // list all function names and their corresponding library names
-Status FunctionListFunc(Server *srv, const redis::Connection *conn, const 
std::string &funcname, std::string *output) {
+Status FunctionListFunc(Server *srv, const redis::Connection *conn, 
engine::Context &ctx, const std::string &funcname,
+                        std::string *output) {
   std::string start_key = engine::kLuaFuncLibPrefix + funcname;
   std::string end_key = start_key;
   end_key.back()++;
 
-  engine::Context ctx(srv->storage);
   rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
   rocksdb::Slice upper_bound(end_key);
   read_options.iterate_upper_bound = &upper_bound;
@@ -603,8 +605,9 @@ Status FunctionDelete(engine::Context &ctx, 
redis::Connection *conn, const std::
   return Status::OK();
 }
 
-Status EvalGenericCommand(redis::Connection *conn, const std::string 
&body_or_sha, const std::vector<std::string> &keys,
-                          const std::vector<std::string> &argv, bool evalsha, 
std::string *output, bool read_only) {
+Status EvalGenericCommand(redis::Connection *conn, engine::Context *ctx, const 
std::string &body_or_sha,
+                          const std::vector<std::string> &keys, const 
std::vector<std::string> &argv, bool evalsha,
+                          std::string *output, bool read_only) {
   Server *srv = conn->GetServer();
   // Use the worker's private Lua VM when entering the read-only mode
   lua_State *lua = conn->Owner()->Lua();
@@ -652,6 +655,7 @@ Status EvalGenericCommand(redis::Connection *conn, const 
std::string &body_or_sh
 
   ScriptRunCtx current_script_run_ctx;
   current_script_run_ctx.conn = conn;
+  current_script_run_ctx.ctx = ctx;
   current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 
0;
   lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 
2).c_str());
   if (!lua_isnil(lua, -1)) {
@@ -820,14 +824,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
   }
 
   std::string output;
-  // TODO: make it possible for multiple redis commands in lua script to use 
the same txn context.
-  {
-    engine::Context ctx(srv->storage);
-    s = conn->ExecuteCommand(ctx, cmd_name, args, cmd.get(), &output);
-    if (!s) {
-      PushError(lua, s.Msg().data());
-      return raise_error ? RaiseError(lua) : 1;
-    }
+  s = conn->ExecuteCommand(*script_run_ctx->ctx, cmd_name, args, cmd.get(), 
&output);
+  if (!s) {
+    PushError(lua, s.Msg().data());
+    return raise_error ? RaiseError(lua) : 1;
   }
 
   srv->FeedMonitorConns(conn, args);
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index 3e7254528..f41cf6e9a 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -26,6 +26,7 @@
 #include "lua.hpp"
 #include "server/redis_connection.h"
 #include "status.h"
+#include "storage/storage.h"
 
 namespace engine {
 struct Context;
@@ -62,23 +63,25 @@ int RedisSetResp(lua_State *lua);
 
 Status CreateFunction(Server *srv, const std::string &body, std::string *sha, 
lua_State *lua, bool need_to_store);
 
-Status EvalGenericCommand(redis::Connection *conn, const std::string 
&body_or_sha, const std::vector<std::string> &keys,
-                          const std::vector<std::string> &argv, bool evalsha, 
std::string *output,
-                          bool read_only = false);
+Status EvalGenericCommand(redis::Connection *conn, engine::Context *ctx, const 
std::string &body_or_sha,
+                          const std::vector<std::string> &keys, const 
std::vector<std::string> &argv, bool evalsha,
+                          std::string *output, bool read_only = false);
 
 bool ScriptExists(lua_State *lua, const std::string &sha);
 
-Status FunctionLoad(redis::Connection *conn, const std::string &script, bool 
need_to_store, bool replace,
-                    std::string *lib_name, bool read_only = false);
-Status FunctionCall(redis::Connection *conn, const std::string &name, const 
std::vector<std::string> &keys,
-                    const std::vector<std::string> &argv, std::string *output, 
bool read_only = false);
-Status FunctionList(Server *srv, const redis::Connection *conn, const 
std::string &libname, bool with_code,
-                    std::string *output);
-Status FunctionListFunc(Server *srv, const redis::Connection *conn, const 
std::string &funcname, std::string *output);
+Status FunctionLoad(redis::Connection *conn, engine::Context *ctx, const 
std::string &script, bool need_to_store,
+                    bool replace, std::string *lib_name, bool read_only = 
false);
+Status FunctionCall(redis::Connection *conn, engine::Context *ctx, const 
std::string &name,
+                    const std::vector<std::string> &keys, const 
std::vector<std::string> &argv, std::string *output,
+                    bool read_only = false);
+Status FunctionList(Server *srv, const redis::Connection *conn, 
engine::Context &ctx, const std::string &libname,
+                    bool with_code, std::string *output);
+Status FunctionListFunc(Server *srv, const redis::Connection *conn, 
engine::Context &ctx, const std::string &funcname,
+                        std::string *output);
 Status FunctionListLib(redis::Connection *conn, const std::string &libname, 
std::string *output);
 Status FunctionDelete(engine::Context &ctx, redis::Connection *conn, const 
std::string &name);
-bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, 
bool need_check_storage = true,
-                        bool read_only = false);
+bool FunctionIsLibExist(redis::Connection *conn, engine::Context *ctx, const 
std::string &libname,
+                        bool need_check_storage = true, bool read_only = 
false);
 
 const char *RedisProtocolToLuaType(lua_State *lua, const char *reply);
 const char *RedisProtocolToLuaTypeInt(lua_State *lua, const char *reply);
@@ -150,6 +153,8 @@ struct ScriptRunCtx {
   int current_slot = -1;
   // the current connection
   redis::Connection *conn = nullptr;
+  // the storage context
+  engine::Context *ctx = nullptr;
 };
 
 /// SaveOnRegistry saves user-defined data to lua REGISTRY

Reply via email to