This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 55e99f22 fix(script): avoid SetCurrentConnection on read-only scripting (#2640)
55e99f22 is described below

commit 55e99f22f8a28cf64390963ef044e456f17fafb0
Author: Twice <[email protected]>
AuthorDate: Sun Nov 3 19:38:09 2024 +0800

    fix(script): avoid SetCurrentConnection on read-only scripting (#2640)
---
 src/server/redis_connection.cc | 12 ------------
 src/server/server.cc           |  4 ++--
 src/server/server.h            |  5 -----
 src/server/worker.cc           |  2 +-
 src/storage/scripting.cc       | 38 +++++++++++++++++++++-----------------
 src/storage/scripting.h        |  6 +++---
 6 files changed, 27 insertions(+), 40 deletions(-)

diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc
index 584b5068..70abfe70 100644
--- a/src/server/redis_connection.cc
+++ b/src/server/redis_connection.cc
@@ -417,22 +417,10 @@ void 
Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
       // No lock guard, because 'exec' command has acquired 
'WorkExclusivityGuard'
     } else if (cmd_flags & kCmdExclusive) {
       exclusivity = srv_->WorkExclusivityGuard();
-
-      // When executing lua script commands that have "exclusive" attribute, 
we need to know current connection,
-      // but we should set current connection after acquiring the 
WorkExclusivityGuard to make it thread-safe
-      srv_->SetCurrentConnection(this);
     } else {
       concurrency = srv_->WorkConcurrencyGuard();
     }
 
-    auto category = attributes->category;
-    if ((category == CommandCategory::Function || category == 
CommandCategory::Script) && (cmd_flags & kCmdReadOnly)) {
-      // FIXME: since read-only script commands are not exclusive,
-      // SetCurrentConnection here is weird and can cause many issues,
-      // we should pass the Connection directly to the lua context instead
-      srv_->SetCurrentConnection(this);
-    }
-
     if (srv_->IsLoading() && !(cmd_flags & kCmdLoading)) {
       Reply(redis::Error({Status::RedisLoading, errRestoringBackup}));
       if (is_multi_exec) multi_error_ = true;
diff --git a/src/server/server.cc b/src/server/server.cc
index e569d12e..5b52eb33 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -102,7 +102,7 @@ Server::Server(engine::Storage *storage, Config *config)
   AdjustOpenFilesLimit();
   slow_log_.SetMaxEntries(config->slowlog_max_len);
   perf_log_.SetMaxEntries(config->profiling_sample_record_max_len);
-  lua_ = lua::CreateState(this);
+  lua_ = lua::CreateState();
 }
 
 Server::~Server() {
@@ -1764,7 +1764,7 @@ Status Server::FunctionSetLib(const std::string &func, 
const std::string &lib) c
 }
 
 void Server::ScriptReset() {
-  auto lua = lua_.exchange(lua::CreateState(this));
+  auto lua = lua_.exchange(lua::CreateState());
   lua::DestroyState(lua);
 }
 
diff --git a/src/server/server.h b/src/server/server.h
index 7d8c8327..3c2ab16e 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -285,9 +285,6 @@ class Server {
   Status ExecPropagatedCommand(const std::vector<std::string> &tokens);
   Status ExecPropagateScriptCommand(const std::vector<std::string> &tokens);
 
-  void SetCurrentConnection(redis::Connection *conn) { curr_connection_ = 
conn; }
-  redis::Connection *GetCurrentConnection() { return curr_connection_; }
-
   LogCollector<PerfEntry> *GetPerfLog() { return &perf_log_; }
   LogCollector<SlowEntry> *GetSlowLog() { return &slow_log_; }
   void SlowlogPushEntryIfNeeded(const std::vector<std::string> *args, uint64_t 
duration, const redis::Connection *conn);
@@ -343,8 +340,6 @@ class Server {
 
   std::atomic<lua_State *> lua_;
 
-  redis::Connection *curr_connection_ = nullptr;
-
   // client counters
   std::atomic<uint64_t> client_id_{1};
   std::atomic<int> connected_clients_{0};
diff --git a/src/server/worker.cc b/src/server/worker.cc
index 4ddf31ad..6420d76c 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -76,7 +76,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv), 
base_(event_base_new())
       }
     }
   }
-  lua_ = lua::CreateState(srv);
+  lua_ = lua::CreateState();
 }
 
 Worker::~Worker() {
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index f38d9422..5768aee8 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -57,15 +57,12 @@ enum {
 
 namespace lua {
 
-lua_State *CreateState(Server *srv) {
+lua_State *CreateState() {
   lua_State *lua = lua_open();
   LoadLibraries(lua);
   RemoveUnsupportedFunctions(lua);
   LoadFuncs(lua);
 
-  lua_pushlightuserdata(lua, srv);
-  lua_setglobal(lua, REDIS_LUA_SERVER_PTR);
-
   EnableGlobalsProtection(lua);
   return lua;
 }
@@ -273,7 +270,10 @@ int RedisRegisterFunction(lua_State *lua) {
   }
 
   // store the map from function name to library name
-  auto s = GetServer(lua)->FunctionSetLib(name, libname);
+  auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua, 
REGISTRY_SCRIPT_RUN_CTX_NAME);
+  CHECK_NOTNULL(script_run_ctx);
+
+  auto s = script_run_ctx->conn->GetServer()->FunctionSetLib(name, libname);
   if (!s) {
     lua_pushstring(lua, "redis.register_function() failed to store 
informantion.");
     return lua_error(lua);
@@ -305,6 +305,12 @@ Status FunctionLoad(redis::Connection *conn, const 
std::string &script, bool nee
     if (!s) return s;
   }
 
+  ScriptRunCtx script_run_ctx;
+  script_run_ctx.conn = conn;
+  script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
+
+  SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
+
   lua_pushstring(lua, libname.c_str());
   lua_setglobal(lua, REDIS_FUNCTION_LIBNAME);
   auto libname_exit = MakeScopeExit([lua] {
@@ -331,6 +337,8 @@ Status FunctionLoad(redis::Connection *conn, const 
std::string &script, bool nee
     return {Status::NotOK, "Error while running new function lib: " + err_msg};
   }
 
+  RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
+
   if (!FunctionIsLibExist(conn, libname, false, read_only)) {
     return {Status::NotOK, "Please register some function in FUNCTION LOAD"};
   }
@@ -396,6 +404,7 @@ Status FunctionCall(redis::Connection *conn, const 
std::string &name, const std:
   }
 
   ScriptRunCtx script_run_ctx;
+  script_run_ctx.conn = conn;
   script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
   lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
   if (!lua_isnil(lua, -1)) {
@@ -642,6 +651,7 @@ Status EvalGenericCommand(redis::Connection *conn, const 
std::string &body_or_sh
   }
 
   ScriptRunCtx current_script_run_ctx;
+  current_script_run_ctx.conn = conn;
   current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 
0;
   lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 
2).c_str());
   if (!lua_isnil(lua, -1)) {
@@ -709,14 +719,6 @@ int RedisCallCommand(lua_State *lua) { return 
RedisGenericCommand(lua, 1); }
 
 int RedisPCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 0); }
 
-Server *GetServer(lua_State *lua) {
-  lua_getglobal(lua, REDIS_LUA_SERVER_PTR);
-  auto srv = reinterpret_cast<Server *>(lua_touserdata(lua, -1));
-  lua_pop(lua, 1);
-
-  return srv;
-}
-
 // TODO: we do not want to repeat same logic as Connection::ExecuteCommands,
 // so the function need to be refactored
 int RedisGenericCommand(lua_State *lua, int raise_error) {
@@ -772,10 +774,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
 
   std::string cmd_name = attributes->name;
 
-  auto srv = GetServer(lua);
+  auto *conn = script_run_ctx->conn;
+  auto *srv = conn->GetServer();
   Config *config = srv->GetConfig();
 
-  redis::Connection *conn = srv->GetCurrentConnection();
   if (config->cluster_enabled) {
     if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) {
       PushError(lua, "Can not run script on cluster, 'no-cluster' flag is 
set");
@@ -901,8 +903,10 @@ int RedisReturnSingleFieldTable(lua_State *lua, const char 
*field) {
 }
 
 int RedisSetResp(lua_State *lua) {
-  auto srv = GetServer(lua);
-  auto conn = srv->GetCurrentConnection();
+  auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua, 
REGISTRY_SCRIPT_RUN_CTX_NAME);
+  CHECK_NOTNULL(script_run_ctx);
+  auto *conn = script_run_ctx->conn;
+  auto *srv = conn->GetServer();
 
   if (lua_gettop(lua) != 1) {
     PushError(lua, "redis.setresp() requires one argument.");
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index 9aa4044b..188f855c 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -35,7 +35,6 @@ inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = 
"f_";
 inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_";
 inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = 
"__redis_registered_";
 inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = 
"__redis_registered_flags_";
-inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr";
 inline constexpr const char REDIS_FUNCTION_LIBNAME[] = 
"REDIS_FUNCTION_LIBNAME";
 inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = 
"REDIS_FUNCTION_NEEDSTORE";
 inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = 
"REDIS_FUNCTION_LIBRARIES";
@@ -43,9 +42,8 @@ inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = 
"SCRIPT_RUN_CTX";
 
 namespace lua {
 
-lua_State *CreateState(Server *srv);
+lua_State *CreateState();
 void DestroyState(lua_State *lua);
-Server *GetServer(lua_State *lua);
 
 void LoadFuncs(lua_State *lua);
 void LoadLibraries(lua_State *lua);
@@ -150,6 +148,8 @@ struct ScriptRunCtx {
   // and is used to detect whether there is cross-slot access
   // between multiple commands in a script or function.
   int current_slot = -1;
+  // the current connection
+  redis::Connection *conn = nullptr;
 };
 
 /// SaveOnRegistry saves user-defined data to lua REGISTRY

Reply via email to