This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new b4b14f0b4 feat(script): pass storage context through scripting (#2901)
b4b14f0b4 is described below
commit b4b14f0b4bd9e4033950541daaa787dc5293e893
Author: Twice <[email protected]>
AuthorDate: Sat Apr 26 20:11:01 2025 +0800
feat(script): pass storage context through scripting (#2901)
Signed-off-by: PragmaTwice <[email protected]>
---
src/commands/cmd_function.cc | 11 +++++----
src/commands/cmd_script.cc | 2 +-
src/server/redis_connection.cc | 2 +-
src/storage/scripting.cc | 52 +++++++++++++++++++++---------------------
src/storage/scripting.h | 29 +++++++++++++----------
5 files changed, 51 insertions(+), 45 deletions(-)
diff --git a/src/commands/cmd_function.cc b/src/commands/cmd_function.cc
index 82773e5f7..1d5ef8461 100644
--- a/src/commands/cmd_function.cc
+++ b/src/commands/cmd_function.cc
@@ -39,7 +39,7 @@ struct CommandFunction : Commander {
}
std::string libname;
- auto s = lua::FunctionLoad(conn, GET_OR_RET(parser.TakeStr()), true,
replace, &libname);
+ auto s = lua::FunctionLoad(conn, &ctx, GET_OR_RET(parser.TakeStr()),
true, replace, &libname);
if (!s) return s;
*output = SimpleString(libname);
@@ -55,21 +55,21 @@ struct CommandFunction : Commander {
with_code = true;
}
- return lua::FunctionList(srv, conn, libname, with_code, output);
+ return lua::FunctionList(srv, conn, ctx, libname, with_code, output);
} else if (parser.EatEqICase("listfunc")) {
std::string funcname;
if (parser.EatEqICase("funcname")) {
funcname = GET_OR_RET(parser.TakeStr());
}
- return lua::FunctionListFunc(srv, conn, funcname, output);
+ return lua::FunctionListFunc(srv, conn, ctx, funcname, output);
} else if (parser.EatEqICase("listlib")) {
auto libname = GET_OR_RET(parser.TakeStr().Prefixed("expect a library
name"));
return lua::FunctionListLib(conn, libname, output);
} else if (parser.EatEqICase("delete")) {
auto libname = GET_OR_RET(parser.TakeStr());
- if (!lua::FunctionIsLibExist(conn, libname)) {
+ if (!lua::FunctionIsLibExist(conn, &ctx, libname)) {
return {Status::NotOK, "no such library"};
}
auto s = lua::FunctionDelete(ctx, conn, libname);
@@ -94,7 +94,8 @@ struct CommandFCall : Commander {
return {Status::NotOK, "Number of keys can't be negative"};
}
- return lua::FunctionCall(conn, args_[1],
std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
+ return lua::FunctionCall(conn, &ctx, args_[1],
+ std::vector<std::string>(args_.begin() + 3,
args_.begin() + 3 + numkeys),
std::vector<std::string>(args_.begin() + 3 +
numkeys, args_.end()), output, read_only);
}
};
diff --git a/src/commands/cmd_script.cc b/src/commands/cmd_script.cc
index 8dac03981..142076de0 100644
--- a/src/commands/cmd_script.cc
+++ b/src/commands/cmd_script.cc
@@ -43,7 +43,7 @@ class CommandEvalImpl : public Commander {
}
return lua::EvalGenericCommand(
- conn, args_[1], std::vector<std::string>(args_.begin() + 3,
args_.begin() + 3 + numkeys),
+ conn, &ctx, args_[1], std::vector<std::string>(args_.begin() + 3,
args_.begin() + 3 + numkeys),
std::vector<std::string>(args_.begin() + 3 + numkeys, args_.end()),
evalsha, output, read_only);
}
};
diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc
index 765a0e517..aa21a5595 100644
--- a/src/server/redis_connection.cc
+++ b/src/server/redis_connection.cc
@@ -373,7 +373,7 @@ static bool IsCmdAllowedInStaleData(const std::string
&cmd_name) {
void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
const Config *config = srv_->GetConfig();
std::string reply;
- std::string password = config->requirepass;
+ const std::string &password = config->requirepass;
while (!to_process_cmds->empty()) {
CommandTokens cmd_tokens = std::move(to_process_cmds->front());
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index 26bd34548..ebd194a23 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -281,8 +281,8 @@ int RedisRegisterFunction(lua_State *lua) {
return 0;
}
-Status FunctionLoad(redis::Connection *conn, const std::string &script, bool
need_to_store, bool replace,
- [[maybe_unused]] std::string *lib_name, bool read_only) {
+Status FunctionLoad(redis::Connection *conn, engine::Context *ctx, const
std::string &script, bool need_to_store,
+ bool replace, [[maybe_unused]] std::string *lib_name, bool
read_only) {
std::string first_line, lua_code;
if (auto pos = script.find('\n'); pos != std::string::npos) {
first_line = script.substr(0, pos);
@@ -296,17 +296,17 @@ Status FunctionLoad(redis::Connection *conn, const
std::string &script, bool nee
auto srv = conn->GetServer();
auto lua = conn->Owner()->Lua();
- if (FunctionIsLibExist(conn, libname, need_to_store, read_only)) {
+ if (FunctionIsLibExist(conn, ctx, libname, need_to_store, read_only)) {
if (!replace) {
return {Status::NotOK, "library already exists, please specify REPLACE
to force load"};
}
- engine::Context ctx(srv->storage);
- auto s = FunctionDelete(ctx, conn, libname);
+ auto s = FunctionDelete(*ctx, conn, libname);
if (!s) return s;
}
ScriptRunCtx script_run_ctx;
script_run_ctx.conn = conn;
+ script_run_ctx.ctx = ctx;
script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
@@ -339,14 +339,15 @@ Status FunctionLoad(redis::Connection *conn, const
std::string &script, bool nee
RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
- if (!FunctionIsLibExist(conn, libname, false, read_only)) {
+ if (!FunctionIsLibExist(conn, ctx, libname, false, read_only)) {
return {Status::NotOK, "Please register some function in FUNCTION LOAD"};
}
return need_to_store ? srv->FunctionSetCode(libname, script) : Status::OK();
}
-bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname,
bool need_check_storage, bool read_only) {
+bool FunctionIsLibExist(redis::Connection *conn, engine::Context *ctx, const
std::string &libname,
+ bool need_check_storage, bool read_only) {
auto srv = conn->GetServer();
auto lua = conn->Owner()->Lua();
@@ -373,14 +374,15 @@ bool FunctionIsLibExist(redis::Connection *conn, const
std::string &libname, boo
if (!s) return false;
std::string lib_name;
- s = FunctionLoad(conn, code, false, false, &lib_name, read_only);
+ s = FunctionLoad(conn, ctx, code, false, false, &lib_name, read_only);
return static_cast<bool>(s);
}
// FunctionCall will firstly find the function in the lua runtime,
// if it is not found, it will try to load the library where the function is
located from storage
-Status FunctionCall(redis::Connection *conn, const std::string &name, const
std::vector<std::string> &keys,
- const std::vector<std::string> &argv, std::string *output,
bool read_only) {
+Status FunctionCall(redis::Connection *conn, engine::Context *ctx, const
std::string &name,
+ const std::vector<std::string> &keys, const
std::vector<std::string> &argv, std::string *output,
+ bool read_only) {
auto srv = conn->GetServer();
auto lua = conn->Owner()->Lua();
@@ -397,7 +399,7 @@ Status FunctionCall(redis::Connection *conn, const
std::string &name, const std:
std::string libcode;
s = srv->FunctionGetCode(libname, &libcode);
if (!s) return s;
- s = FunctionLoad(conn, libcode, false, false, &libname, read_only);
+ s = FunctionLoad(conn, ctx, libcode, false, false, &libname, read_only);
if (!s) return s;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());
@@ -405,6 +407,7 @@ Status FunctionCall(redis::Connection *conn, const
std::string &name, const std:
ScriptRunCtx script_run_ctx;
script_run_ctx.conn = conn;
+ script_run_ctx.ctx = ctx;
script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
if (!lua_isnil(lua, -1)) {
@@ -447,13 +450,12 @@ Status FunctionCall(redis::Connection *conn, const
std::string &name, const std:
}
// list all library names and their code (enabled via `with_code`)
-Status FunctionList(Server *srv, const redis::Connection *conn, const
std::string &libname, bool with_code,
- std::string *output) {
+Status FunctionList(Server *srv, const redis::Connection *conn,
engine::Context &ctx, const std::string &libname,
+ bool with_code, std::string *output) {
std::string start_key = engine::kLuaLibCodePrefix + libname;
std::string end_key = start_key;
end_key.back()++;
- engine::Context ctx(srv->storage);
rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
rocksdb::Slice upper_bound(end_key);
read_options.iterate_upper_bound = &upper_bound;
@@ -487,12 +489,12 @@ Status FunctionList(Server *srv, const redis::Connection
*conn, const std::strin
// extension to Redis Function
// list all function names and their corresponding library names
-Status FunctionListFunc(Server *srv, const redis::Connection *conn, const
std::string &funcname, std::string *output) {
+Status FunctionListFunc(Server *srv, const redis::Connection *conn,
engine::Context &ctx, const std::string &funcname,
+ std::string *output) {
std::string start_key = engine::kLuaFuncLibPrefix + funcname;
std::string end_key = start_key;
end_key.back()++;
- engine::Context ctx(srv->storage);
rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
rocksdb::Slice upper_bound(end_key);
read_options.iterate_upper_bound = &upper_bound;
@@ -603,8 +605,9 @@ Status FunctionDelete(engine::Context &ctx,
redis::Connection *conn, const std::
return Status::OK();
}
-Status EvalGenericCommand(redis::Connection *conn, const std::string
&body_or_sha, const std::vector<std::string> &keys,
- const std::vector<std::string> &argv, bool evalsha,
std::string *output, bool read_only) {
+Status EvalGenericCommand(redis::Connection *conn, engine::Context *ctx, const
std::string &body_or_sha,
+ const std::vector<std::string> &keys, const
std::vector<std::string> &argv, bool evalsha,
+ std::string *output, bool read_only) {
Server *srv = conn->GetServer();
// Use the worker's private Lua VM when entering the read-only mode
lua_State *lua = conn->Owner()->Lua();
@@ -652,6 +655,7 @@ Status EvalGenericCommand(redis::Connection *conn, const
std::string &body_or_sh
ScriptRunCtx current_script_run_ctx;
current_script_run_ctx.conn = conn;
+ current_script_run_ctx.ctx = ctx;
current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites :
0;
lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname +
2).c_str());
if (!lua_isnil(lua, -1)) {
@@ -820,14 +824,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
}
std::string output;
- // TODO: make it possible for multiple redis commands in lua script to use
the same txn context.
- {
- engine::Context ctx(srv->storage);
- s = conn->ExecuteCommand(ctx, cmd_name, args, cmd.get(), &output);
- if (!s) {
- PushError(lua, s.Msg().data());
- return raise_error ? RaiseError(lua) : 1;
- }
+ s = conn->ExecuteCommand(*script_run_ctx->ctx, cmd_name, args, cmd.get(),
&output);
+ if (!s) {
+ PushError(lua, s.Msg().data());
+ return raise_error ? RaiseError(lua) : 1;
}
srv->FeedMonitorConns(conn, args);
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index 3e7254528..f41cf6e9a 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -26,6 +26,7 @@
#include "lua.hpp"
#include "server/redis_connection.h"
#include "status.h"
+#include "storage/storage.h"
namespace engine {
struct Context;
@@ -62,23 +63,25 @@ int RedisSetResp(lua_State *lua);
Status CreateFunction(Server *srv, const std::string &body, std::string *sha,
lua_State *lua, bool need_to_store);
-Status EvalGenericCommand(redis::Connection *conn, const std::string
&body_or_sha, const std::vector<std::string> &keys,
- const std::vector<std::string> &argv, bool evalsha,
std::string *output,
- bool read_only = false);
+Status EvalGenericCommand(redis::Connection *conn, engine::Context *ctx, const
std::string &body_or_sha,
+ const std::vector<std::string> &keys, const
std::vector<std::string> &argv, bool evalsha,
+ std::string *output, bool read_only = false);
bool ScriptExists(lua_State *lua, const std::string &sha);
-Status FunctionLoad(redis::Connection *conn, const std::string &script, bool
need_to_store, bool replace,
- std::string *lib_name, bool read_only = false);
-Status FunctionCall(redis::Connection *conn, const std::string &name, const
std::vector<std::string> &keys,
- const std::vector<std::string> &argv, std::string *output,
bool read_only = false);
-Status FunctionList(Server *srv, const redis::Connection *conn, const
std::string &libname, bool with_code,
- std::string *output);
-Status FunctionListFunc(Server *srv, const redis::Connection *conn, const
std::string &funcname, std::string *output);
+Status FunctionLoad(redis::Connection *conn, engine::Context *ctx, const
std::string &script, bool need_to_store,
+ bool replace, std::string *lib_name, bool read_only =
false);
+Status FunctionCall(redis::Connection *conn, engine::Context *ctx, const
std::string &name,
+ const std::vector<std::string> &keys, const
std::vector<std::string> &argv, std::string *output,
+ bool read_only = false);
+Status FunctionList(Server *srv, const redis::Connection *conn,
engine::Context &ctx, const std::string &libname,
+ bool with_code, std::string *output);
+Status FunctionListFunc(Server *srv, const redis::Connection *conn,
engine::Context &ctx, const std::string &funcname,
+ std::string *output);
Status FunctionListLib(redis::Connection *conn, const std::string &libname,
std::string *output);
Status FunctionDelete(engine::Context &ctx, redis::Connection *conn, const
std::string &name);
-bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname,
bool need_check_storage = true,
- bool read_only = false);
+bool FunctionIsLibExist(redis::Connection *conn, engine::Context *ctx, const
std::string &libname,
+ bool need_check_storage = true, bool read_only =
false);
const char *RedisProtocolToLuaType(lua_State *lua, const char *reply);
const char *RedisProtocolToLuaTypeInt(lua_State *lua, const char *reply);
@@ -150,6 +153,8 @@ struct ScriptRunCtx {
int current_slot = -1;
// the current connection
redis::Connection *conn = nullptr;
+ // the storage context
+ engine::Context *ctx = nullptr;
};
/// SaveOnRegistry saves user-defined data to lua REGISTRY