This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 55e99f22 fix(script): avoid SetCurrentConnection on read-only scripting (#2640)
55e99f22 is described below
commit 55e99f22f8a28cf64390963ef044e456f17fafb0
Author: Twice <[email protected]>
AuthorDate: Sun Nov 3 19:38:09 2024 +0800
fix(script): avoid SetCurrentConnection on read-only scripting (#2640)
---
src/server/redis_connection.cc | 12 ------------
src/server/server.cc | 4 ++--
src/server/server.h | 5 -----
src/server/worker.cc | 2 +-
src/storage/scripting.cc | 38 +++++++++++++++++++++-----------------
src/storage/scripting.h | 6 +++---
6 files changed, 27 insertions(+), 40 deletions(-)
diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc
index 584b5068..70abfe70 100644
--- a/src/server/redis_connection.cc
+++ b/src/server/redis_connection.cc
@@ -417,22 +417,10 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
// No lock guard, because 'exec' command has acquired
'WorkExclusivityGuard'
} else if (cmd_flags & kCmdExclusive) {
exclusivity = srv_->WorkExclusivityGuard();
-
- // When executing lua script commands that have "exclusive" attribute,
we need to know current connection,
- // but we should set current connection after acquiring the
WorkExclusivityGuard to make it thread-safe
- srv_->SetCurrentConnection(this);
} else {
concurrency = srv_->WorkConcurrencyGuard();
}
- auto category = attributes->category;
- if ((category == CommandCategory::Function || category ==
CommandCategory::Script) && (cmd_flags & kCmdReadOnly)) {
- // FIXME: since read-only script commands are not exclusive,
- // SetCurrentConnection here is weird and can cause many issues,
- // we should pass the Connection directly to the lua context instead
- srv_->SetCurrentConnection(this);
- }
-
if (srv_->IsLoading() && !(cmd_flags & kCmdLoading)) {
Reply(redis::Error({Status::RedisLoading, errRestoringBackup}));
if (is_multi_exec) multi_error_ = true;
diff --git a/src/server/server.cc b/src/server/server.cc
index e569d12e..5b52eb33 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -102,7 +102,7 @@ Server::Server(engine::Storage *storage, Config *config)
AdjustOpenFilesLimit();
slow_log_.SetMaxEntries(config->slowlog_max_len);
perf_log_.SetMaxEntries(config->profiling_sample_record_max_len);
- lua_ = lua::CreateState(this);
+ lua_ = lua::CreateState();
}
Server::~Server() {
@@ -1764,7 +1764,7 @@ Status Server::FunctionSetLib(const std::string &func, const std::string &lib) c
}
void Server::ScriptReset() {
- auto lua = lua_.exchange(lua::CreateState(this));
+ auto lua = lua_.exchange(lua::CreateState());
lua::DestroyState(lua);
}
diff --git a/src/server/server.h b/src/server/server.h
index 7d8c8327..3c2ab16e 100644
--- a/src/server/server.h
+++ b/src/server/server.h
@@ -285,9 +285,6 @@ class Server {
Status ExecPropagatedCommand(const std::vector<std::string> &tokens);
Status ExecPropagateScriptCommand(const std::vector<std::string> &tokens);
- void SetCurrentConnection(redis::Connection *conn) { curr_connection_ =
conn; }
- redis::Connection *GetCurrentConnection() { return curr_connection_; }
-
LogCollector<PerfEntry> *GetPerfLog() { return &perf_log_; }
LogCollector<SlowEntry> *GetSlowLog() { return &slow_log_; }
void SlowlogPushEntryIfNeeded(const std::vector<std::string> *args, uint64_t
duration, const redis::Connection *conn);
@@ -343,8 +340,6 @@ class Server {
std::atomic<lua_State *> lua_;
- redis::Connection *curr_connection_ = nullptr;
-
// client counters
std::atomic<uint64_t> client_id_{1};
std::atomic<int> connected_clients_{0};
diff --git a/src/server/worker.cc b/src/server/worker.cc
index 4ddf31ad..6420d76c 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -76,7 +76,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv), base_(event_base_new())
}
}
}
- lua_ = lua::CreateState(srv);
+ lua_ = lua::CreateState();
}
Worker::~Worker() {
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index f38d9422..5768aee8 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -57,15 +57,12 @@ enum {
namespace lua {
-lua_State *CreateState(Server *srv) {
+lua_State *CreateState() {
lua_State *lua = lua_open();
LoadLibraries(lua);
RemoveUnsupportedFunctions(lua);
LoadFuncs(lua);
- lua_pushlightuserdata(lua, srv);
- lua_setglobal(lua, REDIS_LUA_SERVER_PTR);
-
EnableGlobalsProtection(lua);
return lua;
}
@@ -273,7 +270,10 @@ int RedisRegisterFunction(lua_State *lua) {
}
// store the map from function name to library name
- auto s = GetServer(lua)->FunctionSetLib(name, libname);
+ auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua,
REGISTRY_SCRIPT_RUN_CTX_NAME);
+ CHECK_NOTNULL(script_run_ctx);
+
+ auto s = script_run_ctx->conn->GetServer()->FunctionSetLib(name, libname);
if (!s) {
lua_pushstring(lua, "redis.register_function() failed to store
informantion.");
return lua_error(lua);
@@ -305,6 +305,12 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee
if (!s) return s;
}
+ ScriptRunCtx script_run_ctx;
+ script_run_ctx.conn = conn;
+ script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
+
+ SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
+
lua_pushstring(lua, libname.c_str());
lua_setglobal(lua, REDIS_FUNCTION_LIBNAME);
auto libname_exit = MakeScopeExit([lua] {
@@ -331,6 +337,8 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee
return {Status::NotOK, "Error while running new function lib: " + err_msg};
}
+ RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
+
if (!FunctionIsLibExist(conn, libname, false, read_only)) {
return {Status::NotOK, "Please register some function in FUNCTION LOAD"};
}
@@ -396,6 +404,7 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std:
}
ScriptRunCtx script_run_ctx;
+ script_run_ctx.conn = conn;
script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
if (!lua_isnil(lua, -1)) {
@@ -642,6 +651,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh
}
ScriptRunCtx current_script_run_ctx;
+ current_script_run_ctx.conn = conn;
current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites :
0;
lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname +
2).c_str());
if (!lua_isnil(lua, -1)) {
@@ -709,14 +719,6 @@ int RedisCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 1); }
int RedisPCallCommand(lua_State *lua) { return RedisGenericCommand(lua, 0); }
-Server *GetServer(lua_State *lua) {
- lua_getglobal(lua, REDIS_LUA_SERVER_PTR);
- auto srv = reinterpret_cast<Server *>(lua_touserdata(lua, -1));
- lua_pop(lua, 1);
-
- return srv;
-}
-
// TODO: we do not want to repeat same logic as Connection::ExecuteCommands,
// so the function need to be refactored
int RedisGenericCommand(lua_State *lua, int raise_error) {
@@ -772,10 +774,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
std::string cmd_name = attributes->name;
- auto srv = GetServer(lua);
+ auto *conn = script_run_ctx->conn;
+ auto *srv = conn->GetServer();
Config *config = srv->GetConfig();
- redis::Connection *conn = srv->GetCurrentConnection();
if (config->cluster_enabled) {
if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) {
PushError(lua, "Can not run script on cluster, 'no-cluster' flag is
set");
@@ -901,8 +903,10 @@ int RedisReturnSingleFieldTable(lua_State *lua, const char *field) {
}
int RedisSetResp(lua_State *lua) {
- auto srv = GetServer(lua);
- auto conn = srv->GetCurrentConnection();
+ auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua,
REGISTRY_SCRIPT_RUN_CTX_NAME);
+ CHECK_NOTNULL(script_run_ctx);
+ auto *conn = script_run_ctx->conn;
+ auto *srv = conn->GetServer();
if (lua_gettop(lua) != 1) {
PushError(lua, "redis.setresp() requires one argument.");
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index 9aa4044b..188f855c 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -35,7 +35,6 @@ inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_";
inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_";
inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] =
"__redis_registered_";
inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] =
"__redis_registered_flags_";
-inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr";
inline constexpr const char REDIS_FUNCTION_LIBNAME[] =
"REDIS_FUNCTION_LIBNAME";
inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] =
"REDIS_FUNCTION_NEEDSTORE";
inline constexpr const char REDIS_FUNCTION_LIBRARIES[] =
"REDIS_FUNCTION_LIBRARIES";
@@ -43,9 +42,8 @@ inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX";
namespace lua {
-lua_State *CreateState(Server *srv);
+lua_State *CreateState();
void DestroyState(lua_State *lua);
-Server *GetServer(lua_State *lua);
void LoadFuncs(lua_State *lua);
void LoadLibraries(lua_State *lua);
@@ -150,6 +148,8 @@ struct ScriptRunCtx {
// and is used to detect whether there is cross-slot access
// between multiple commands in a script or function.
int current_slot = -1;
+ // the current connection
+ redis::Connection *conn = nullptr;
};
/// SaveOnRegistry saves user-defined data to lua REGISTRY