This is an automated email from the ASF dual-hosted git repository.
hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new a86d3173 feat(script): support script flags of Eval script and
Function (#2446)
a86d3173 is described below
commit a86d31731774d165f6c3669c508eac5f2774d8de
Author: SiLe Zhou <[email protected]>
AuthorDate: Wed Aug 14 19:49:09 2024 +0800
feat(script): support script flags of Eval script and Function (#2446)
This PR adds support for three script flags in Eval scripts and Functions:
`no-writes`, `no-cluster`, and `allow-cross-slot-keys`.
Before a Lua script runs, `SaveOnRegistry` stores its run context (the parsed
flags plus the slot tracked for cross-slot checks) in the Lua registry under
`REGISTRY_SCRIPT_RUN_CTX_NAME`. While the script runs, `GetFromRegistry`
retrieves the context of the currently executing script. After the script
finishes, `RemoveFromRegistry` resets the entry to `nil`.
For APIs like `EVAL`, `SCRIPT LOAD`, and `FUNCTION LOAD`, the flags parsed
from the script's shebang line are stored in the Lua global variable
`f_<sha>_flags_`. The flags passed to `redis.register_function()` are stored
in the Lua global variable `__redis_registered_flags_<funcname>`.
---
src/cluster/cluster.cc | 23 ++-
src/cluster/cluster.h | 3 +-
src/server/worker.cc | 2 +-
src/storage/scripting.cc | 252 ++++++++++++++++++-----
src/storage/scripting.h | 90 ++++++++-
tests/gocase/unit/scripting/function_test.go | 273 ++++++++++++++++++++++++-
tests/gocase/unit/scripting/scripting_test.go | 278 ++++++++++++++++++++++++++
7 files changed, 867 insertions(+), 54 deletions(-)
diff --git a/src/cluster/cluster.cc b/src/cluster/cluster.cc
index a38c0868..6a90d39d 100644
--- a/src/cluster/cluster.cc
+++ b/src/cluster/cluster.cc
@@ -824,7 +824,7 @@ bool Cluster::IsWriteForbiddenSlot(int slot) const {
}
Status Cluster::CanExecByMySelf(const redis::CommandAttributes *attributes,
const std::vector<std::string> &cmd_tokens,
- redis::Connection *conn) {
+ redis::Connection *conn, lua::ScriptRunCtx
*script_run_ctx) {
std::vector<int> keys_indexes;
// No keys
@@ -849,6 +849,21 @@ Status Cluster::CanExecByMySelf(const
redis::CommandAttributes *attributes, cons
return {Status::RedisClusterDown, "Hash slot not served"};
}
+ bool cross_slot_ok = false;
+ if (script_run_ctx) {
+ if (script_run_ctx->current_slot != -1 && script_run_ctx->current_slot !=
slot) {
+ if (getNodeIDBySlot(script_run_ctx->current_slot) !=
getNodeIDBySlot(slot)) {
+ return {Status::RedisMoved, fmt::format("{} {}:{}", slot,
slots_nodes_[slot]->host, slots_nodes_[slot]->port)};
+ }
+ if (!(script_run_ctx->flags &
lua::ScriptFlagType::kScriptAllowCrossSlotKeys)) {
+ return {Status::RedisCrossSlot, "Script attempted to access keys that
do not hash to the same slot"};
+ }
+ }
+
+ script_run_ctx->current_slot = slot;
+ cross_slot_ok = true;
+ }
+
if (myself_ && myself_ == slots_nodes_[slot]) {
// We use central controller to manage the topology of the cluster.
// Server can't change the topology directly, so we record the migrated
slots
@@ -887,7 +902,11 @@ Status Cluster::CanExecByMySelf(const
redis::CommandAttributes *attributes, cons
return Status::OK(); // My master is serving this slot
}
- return {Status::RedisMoved, fmt::format("{} {}:{}", slot,
slots_nodes_[slot]->host, slots_nodes_[slot]->port)};
+ if (!cross_slot_ok) {
+ return {Status::RedisMoved, fmt::format("{} {}:{}", slot,
slots_nodes_[slot]->host, slots_nodes_[slot]->port)};
+ }
+
+ return Status::OK();
}
// Only HARD mode is meaningful to the Kvrocks cluster,
diff --git a/src/cluster/cluster.h b/src/cluster/cluster.h
index e595666c..468c154d 100644
--- a/src/cluster/cluster.h
+++ b/src/cluster/cluster.h
@@ -35,6 +35,7 @@
#include "redis_slot.h"
#include "server/redis_connection.h"
#include "status.h"
+#include "storage/scripting.h"
class ClusterNode {
public:
@@ -83,7 +84,7 @@ class Cluster {
bool IsNotMaster();
bool IsWriteForbiddenSlot(int slot) const;
Status CanExecByMySelf(const redis::CommandAttributes *attributes, const
std::vector<std::string> &cmd_tokens,
- redis::Connection *conn);
+ redis::Connection *conn, lua::ScriptRunCtx
*script_run_ctx = nullptr);
Status SetMasterSlaveRepl();
Status MigrateSlotRange(const SlotRange &slot_range, const std::string
&dst_node_id,
SyncMigrateContext *blocking_ctx = nullptr);
diff --git a/src/server/worker.cc b/src/server/worker.cc
index 22054e1f..0b37bcef 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -72,7 +72,7 @@ Worker::Worker(Server *srv, Config *config) : srv(srv),
base_(event_base_new())
LOG(INFO) << "[worker] Listening on: " << bind << ":" << *port;
}
}
- lua_ = lua::CreateState(srv, true);
+ lua_ = lua::CreateState(srv);
}
Worker::~Worker() {
diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc
index 987f9f0d..bfe1016e 100644
--- a/src/storage/scripting.cc
+++ b/src/storage/scripting.cc
@@ -57,11 +57,11 @@ enum {
namespace lua {
-lua_State *CreateState(Server *srv, bool read_only) {
+lua_State *CreateState(Server *srv) {
lua_State *lua = lua_open();
LoadLibraries(lua);
RemoveUnsupportedFunctions(lua);
- LoadFuncs(lua, read_only);
+ LoadFuncs(lua);
lua_pushlightuserdata(lua, srv);
lua_setglobal(lua, REDIS_LUA_SERVER_PTR);
@@ -75,7 +75,7 @@ void DestroyState(lua_State *lua) {
lua_close(lua);
}
-void LoadFuncs(lua_State *lua, bool read_only) {
+void LoadFuncs(lua_State *lua) {
lua_newtable(lua);
/* redis.call */
@@ -127,11 +127,6 @@ void LoadFuncs(lua_State *lua, bool read_only) {
lua_pushcfunction(lua, RedisStatusReplyCommand);
lua_settable(lua, -3);
- /* redis.read_only */
- lua_pushstring(lua, "read_only");
- lua_pushboolean(lua, read_only);
- lua_settable(lua, -3);
-
/* redis.register_function */
lua_pushstring(lua, "register_function");
lua_pushcfunction(lua, RedisRegisterFunction);
@@ -226,8 +221,8 @@ int RedisLogCommand(lua_State *lua) {
int RedisRegisterFunction(lua_State *lua) {
int argc = lua_gettop(lua);
- if (argc != 2) {
- lua_pushstring(lua, "redis.register_function() requires two arguments.");
+ if (argc < 2 || argc > 3) {
+ lua_pushstring(lua, "wrong number of arguments to
redis.register_function().");
return lua_error(lua);
}
@@ -243,6 +238,15 @@ int RedisRegisterFunction(lua_State *lua) {
// set this function to global
std::string name = lua_tostring(lua, 1);
+ if (argc == 3) {
+ auto flags = ExtractFlagsFromRegisterFunction(lua);
+ if (!flags) {
+ lua_pushstring(lua, flags.Msg().c_str());
+ return lua_error(lua);
+ }
+ lua_pushinteger(lua, static_cast<lua_Integer>(flags.GetValue()));
+ lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
+ }
lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());
// set this function name to REDIS_FUNCTION_LIBRARIES[libname]
@@ -274,7 +278,6 @@ int RedisRegisterFunction(lua_State *lua) {
lua_pushstring(lua, "redis.register_function() failed to store
informantion.");
return lua_error(lua);
}
-
return 0;
}
@@ -288,31 +291,8 @@ Status FunctionLoad(redis::Connection *conn, const
std::string &script, bool nee
return {Status::NotOK, "Expect a Shebang statement in the first line"};
}
- static constexpr const char *shebang_prefix = "#!lua";
- static constexpr const char *shebang_libname_prefix = "name=";
-
- auto first_line_split = util::Split(first_line, " \r\t");
- if (first_line_split.empty() || first_line_split[0] != shebang_prefix) {
- return {Status::NotOK, "Expect a Shebang statement in the first line, e.g.
`#!lua name=mylib`"};
- }
-
- size_t libname_pos = 1;
- for (; libname_pos < first_line_split.size(); ++libname_pos) {
- if (util::HasPrefix(first_line_split[libname_pos],
shebang_libname_prefix)) {
- break;
- }
- }
-
- if (libname_pos >= first_line_split.size()) {
- return {Status::NotOK, "Expect library name in the Shebang statement, e.g.
`#!lua name=mylib`"};
- }
+ const auto libname = GET_OR_RET(ExtractLibNameFromShebang(first_line));
- auto libname =
first_line_split[libname_pos].substr(strlen(shebang_libname_prefix));
- *lib_name = libname;
- if (libname.empty() ||
- std::any_of(libname.begin(), libname.end(), [](char v) { return
!std::isalnum(v) && v != '_'; })) {
- return {Status::NotOK, "Expect a valid library name in the Shebang
statement"};
- }
auto srv = conn->GetServer();
auto lua = read_only ? conn->Owner()->Lua() : srv->Lua();
@@ -409,13 +389,24 @@ Status FunctionCall(redis::Connection *conn, const
std::string &name, const std:
std::string libcode;
s = srv->FunctionGetCode(libname, &libcode);
if (!s) return s;
-
s = FunctionLoad(conn, libcode, false, false, &libname, read_only);
if (!s) return s;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());
}
+ ScriptRunCtx script_run_ctx;
+ script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
+ lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
+ if (!lua_isnil(lua, -1)) {
+ // It should be ensured that the conversion is successful
+ auto function_flags = lua_tointeger(lua, -1);
+ script_run_ctx.flags |= function_flags;
+ }
+ lua_pop(lua, 1);
+
+ SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
+
PushArray(lua, keys);
PushArray(lua, argv);
if (lua_pcall(lua, 2, 1, -4)) {
@@ -427,6 +418,22 @@ Status FunctionCall(redis::Connection *conn, const
std::string &name, const std:
lua_pop(lua, 2);
}
+ RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
+
+ /* Call the Lua garbage collector from time to time to avoid a
+ * full cycle performed by Lua, which adds too latency.
+ *
+ * The call is performed every LUA_GC_CYCLE_PERIOD executed commands
+ * (and for LUA_GC_CYCLE_PERIOD collection steps) because calling it
+ * for every command uses too much CPU. */
+ constexpr int64_t LUA_GC_CYCLE_PERIOD = 50;
+ static int64_t gc_count = 0;
+
+ gc_count++;
+ if (gc_count == LUA_GC_CYCLE_PERIOD) {
+ lua_gc(lua, LUA_GCSTEP, LUA_GC_CYCLE_PERIOD);
+ gc_count = 0;
+ }
return Status::OK();
}
@@ -572,6 +579,8 @@ Status FunctionDelete(Server *srv, const std::string &name)
{
std::string func = lua_tostring(lua, -1);
lua_pushnil(lua);
lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + func).c_str());
+ lua_pushnil(lua);
+ lua_setglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + func).c_str());
auto _ = storage->Delete(rocksdb::WriteOptions(), cf,
engine::kLuaFuncLibPrefix + func);
lua_pop(lua, 1);
}
@@ -590,7 +599,6 @@ Status FunctionDelete(Server *srv, const std::string &name)
{
Status EvalGenericCommand(redis::Connection *conn, const std::string
&body_or_sha, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, bool evalsha,
std::string *output, bool read_only) {
Server *srv = conn->GetServer();
-
// Use the worker's private Lua VM when entering the read-only mode
lua_State *lua = read_only ? conn->Owner()->Lua() : srv->Lua();
@@ -635,6 +643,18 @@ Status EvalGenericCommand(redis::Connection *conn, const
std::string &body_or_sh
lua_getglobal(lua, funcname);
}
+ ScriptRunCtx current_script_run_ctx;
+ current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites :
0;
+ lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname +
2).c_str());
+ if (!lua_isnil(lua, -1)) {
+ // It should be ensured that the conversion is successful
+ auto script_flags = lua_tointeger(lua, -1);
+ current_script_run_ctx.flags |= script_flags;
+ }
+ lua_pop(lua, 1);
+
+ SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, ¤t_script_run_ctx);
+
// For the Lua script, should be always run with RESP2 protocol,
// unless users explicitly set the protocol version in the script via
`redis.setresp`.
// So we need to save the current protocol version and set it to RESP2,
@@ -661,6 +681,7 @@ Status EvalGenericCommand(redis::Connection *conn, const
std::string &body_or_sh
lua_setglobal(lua, "KEYS");
lua_pushnil(lua);
lua_setglobal(lua, "ARGV");
+ RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);
/* Call the Lua garbage collector from time to time to avoid a
* full cycle performed by Lua, which adds too latency.
@@ -701,10 +722,8 @@ Server *GetServer(lua_State *lua) {
// TODO: we do not want to repeat same logic as Connection::ExecuteCommands,
// so the function need to be refactored
int RedisGenericCommand(lua_State *lua, int raise_error) {
- lua_getglobal(lua, "redis");
- lua_getfield(lua, -1, "read_only");
- int read_only = lua_toboolean(lua, -1);
- lua_pop(lua, 2);
+ auto *script_run_ctx = GetFromRegistry<ScriptRunCtx>(lua,
REGISTRY_SCRIPT_RUN_CTX_NAME);
+ CHECK_NOTNULL(script_run_ctx);
int argc = lua_gettop(lua);
if (argc == 0) {
@@ -738,7 +757,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
auto attributes = cmd->GetAttributes();
auto cmd_flags = attributes->GenerateFlags(args);
- if (read_only && !(cmd_flags & redis::kCmdReadOnly)) {
+ if ((script_run_ctx->flags & ScriptFlagType::kScriptNoWrites) && !(cmd_flags
& redis::kCmdReadOnly)) {
PushError(lua, "Write commands are not allowed from read-only scripts");
return raise_error ? RaiseError(lua) : 1;
}
@@ -760,9 +779,17 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
redis::Connection *conn = srv->GetCurrentConnection();
if (config->cluster_enabled) {
- auto s = srv->cluster->CanExecByMySelf(attributes, args, conn);
+ if (script_run_ctx->flags & ScriptFlagType::kScriptNoCluster) {
+ PushError(lua, "Can not run script on cluster, 'no-cluster' flag is
set");
+ return raise_error ? RaiseError(lua) : 1;
+ }
+ auto s = srv->cluster->CanExecByMySelf(attributes, args, conn,
script_run_ctx);
if (!s.IsOK()) {
- PushError(lua, redis::StatusToRedisErrorMsg(s).c_str());
+ if (s.Is<Status::RedisMoved>()) {
+ PushError(lua, "Script attempted to access a non local key in a
cluster node script");
+ } else {
+ PushError(lua, redis::StatusToRedisErrorMsg(s).c_str());
+ }
return raise_error ? RaiseError(lua) : 1;
}
}
@@ -1463,7 +1490,23 @@ Status CreateFunction(Server *srv, const std::string
&body, std::string *sha, lu
std::copy(sha->begin(), sha->end(), funcname + 2);
}
- if (luaL_loadbuffer(lua, body.c_str(), body.size(), "@user_script")) {
+ std::string_view lua_code(body);
+ // Cache the flags of the current script
+ ScriptFlags script_flags = 0;
+ if (auto pos = body.find('\n'); pos != std::string::npos) {
+ std::string_view first_line(body.data(), pos);
+ if (first_line.substr(0, 2) == "#!") {
+ lua_code = lua_code.substr(pos + 1);
+ }
+ script_flags = GET_OR_RET(ExtractFlagsFromShebang(first_line));
+ } else {
+ // scripts without #! can run commands that access keys belonging to
different cluster hash slots
+ script_flags = kScriptAllowCrossSlotKeys;
+ }
+ lua_pushinteger(lua, static_cast<lua_Integer>(script_flags));
+ lua_setglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, *sha).c_str());
+
+ if (luaL_loadbuffer(lua, lua_code.data(), lua_code.size(), "@user_script")) {
std::string err_msg = lua_tostring(lua, -1);
lua_pop(lua, 1);
return {Status::NotOK, "Error while compiling new script: " + err_msg};
@@ -1474,4 +1517,123 @@ Status CreateFunction(Server *srv, const std::string
&body, std::string *sha, lu
return need_to_store ? srv->ScriptSet(*sha, body) : Status::OK();
}
+/// ExtractLibNameFromShebang parses the shebang line of a Function library, e.g. `#!lua name=mylib`.
+///
+/// The line must start with `#!lua` and carry exactly one `name=<libname>` option, where <libname>
+/// is non-empty and made only of alphanumeric characters (std::isalnum) or underscores.
+/// Tokens are separated by spaces, tabs or carriage returns, so a CRLF-terminated first line parses
+/// like an LF-terminated one.
+/// Returns the library name, or a NotOK status describing the first problem found.
+[[nodiscard]] StatusOr<std::string> ExtractLibNameFromShebang(std::string_view shebang) {
+  static constexpr std::string_view lua_shebang_prefix = "#!lua";
+  static constexpr std::string_view shebang_libname_prefix = "name=";
+
+  if (shebang.substr(0, 2) != "#!") {
+    return {Status::NotOK, "Missing library meta"};
+  }
+
+  // Split on the same separator set the previous FUNCTION LOAD parser used (" \r\t"), so a
+  // trailing '\r' or a tab never ends up inside the engine, option or library-name token.
+  auto shebang_splits = util::Split(shebang, " \r\t");
+  if (shebang_splits.empty() || shebang_splits[0] != lua_shebang_prefix) {
+    // Never index an empty vector when building the message.
+    const std::string engine = shebang_splits.empty() ? std::string() : shebang_splits[0];
+    return {Status::NotOK, "Unexpected engine in script shebang: " + engine};
+  }
+
+  std::string libname;
+  bool found_libname = false;
+  for (size_t i = 1; i < shebang_splits.size(); i++) {
+    std::string_view shebang_split_sv = shebang_splits[i];
+    if (shebang_split_sv.substr(0, shebang_libname_prefix.size()) != shebang_libname_prefix) {
+      return {Status::NotOK, "Unknown lua shebang option: " + shebang_splits[i]};
+    }
+    if (found_libname) {
+      return {Status::NotOK, "Redundant library name in script shebang"};
+    }
+
+    libname = shebang_split_sv.substr(shebang_libname_prefix.size());
+    // std::isalnum() has undefined behavior for negative char values (e.g. UTF-8 bytes),
+    // so widen through unsigned char first.
+    auto is_invalid_char = [](char v) { return !std::isalnum(static_cast<unsigned char>(v)) && v != '_'; };
+    if (libname.empty() || std::any_of(libname.begin(), libname.end(), is_invalid_char)) {
+      return {
+          Status::NotOK,
+          "Library names can only contain letters, numbers, or underscores(_) and must be at least one character long"};
+    }
+    found_libname = true;
+  }
+
+  if (found_libname) return libname;
+  return {Status::NotOK, "Library name was not given"};
+}
+
+/// GetFlagsFromStrings folds a list of flag names into a single ScriptFlags bit set.
+///
+/// Recognized names: "no-writes", "no-cluster", "allow-cross-slot-keys".
+/// "allow-oom" and "allow-stale" are known but rejected with NotSupported; any other name is
+/// rejected with NotOK. The first offending name stops the scan.
+[[nodiscard]] StatusOr<ScriptFlags> GetFlagsFromStrings(const std::vector<std::string> &flags_content) {
+  ScriptFlags combined = 0;
+  for (const std::string &name : flags_content) {
+    ScriptFlags bit = 0;
+    if (name == "no-writes") {
+      bit = kScriptNoWrites;
+    } else if (name == "no-cluster") {
+      bit = kScriptNoCluster;
+    } else if (name == "allow-cross-slot-keys") {
+      bit = kScriptAllowCrossSlotKeys;
+    } else if (name == "allow-oom") {
+      return {Status::NotSupported, "allow-oom is not supported yet"};
+    } else if (name == "allow-stale") {
+      return {Status::NotSupported, "allow-stale is not supported yet"};
+    } else {
+      return {Status::NotOK, "Unknown flag given: " + name};
+    }
+    combined |= bit;
+  }
+  return combined;
+}
+
+/// ExtractFlagsFromShebang returns the ScriptFlags declared on the first line of an Eval script.
+///
+/// - A line starting with `#!` must be a `#!lua` shebang whose only accepted option is a single
+///   `flags=<flag>[,<flag>...]`; the result is the OR of the listed flags (0 when no option is given).
+/// - Any other line means the script has no shebang; such scripts may access keys that hash to
+///   different cluster slots, so the result is kScriptAllowCrossSlotKeys.
+/// Tokens are separated by spaces, tabs or carriage returns, so a CRLF-terminated first line parses
+/// the same as an LF-terminated one.
+[[nodiscard]] StatusOr<ScriptFlags> ExtractFlagsFromShebang(std::string_view shebang) {
+  static constexpr std::string_view lua_shebang_prefix = "#!lua";
+  static constexpr std::string_view shebang_flags_prefix = "flags=";
+
+  ScriptFlags result_flags = 0;
+  if (shebang.substr(0, 2) == "#!") {
+    // Splitting on '\r' as well keeps "no-writes\r" from being reported as an unknown flag when the
+    // caller's first line still carries the CR of a CRLF line ending.
+    auto shebang_splits = util::Split(shebang, " \r\t");
+    if (shebang_splits.empty() || shebang_splits[0] != lua_shebang_prefix) {
+      // Never index an empty vector when building the message.
+      const std::string engine = shebang_splits.empty() ? std::string() : shebang_splits[0];
+      return {Status::NotOK, "Unexpected engine in script shebang: " + engine};
+    }
+    bool found_flags = false;
+    for (size_t i = 1; i < shebang_splits.size(); i++) {
+      std::string_view shebang_split_sv = shebang_splits[i];
+      if (shebang_split_sv.substr(0, shebang_flags_prefix.size()) != shebang_flags_prefix) {
+        return {Status::NotOK, "Unknown lua shebang option: " + shebang_splits[i]};
+      }
+      if (found_flags) {
+        return {Status::NotOK, "Redundant flags in script shebang"};
+      }
+      auto flags_content = util::Split(shebang_split_sv.substr(shebang_flags_prefix.size()), ",");
+      result_flags |= GET_OR_RET(GetFlagsFromStrings(flags_content));
+      found_flags = true;
+    }
+  } else {
+    // Scripts without #! can run commands that access keys belonging to different cluster hash slots,
+    // but ones with #! inherit the default flags, so they cannot.
+    result_flags = kScriptAllowCrossSlotKeys;
+  }
+
+  return result_flags;
+}
+
+/// ExtractFlagsFromRegisterFunction converts the optional third argument of redis.register_function,
+/// a Lua array of flag strings such as { 'no-writes', 'no-cluster' }, into ScriptFlags.
+///
+/// Expects that argument on the top of the Lua stack. It is popped on every return path
+/// (success or error), so the stack is left as it was before the argument was pushed.
+[[nodiscard]] StatusOr<ScriptFlags> ExtractFlagsFromRegisterFunction(lua_State *lua) {
+  static constexpr const char *invalid_flags_msg =
+      "Expects a valid flags argument to register_function, e.g. flags={ 'no-writes' }";
+
+  if (!lua_istable(lua, -1)) {
+    // Pop the non-table argument so error and success paths leave the same stack.
+    lua_pop(lua, 1);
+    return {Status::NotOK, invalid_flags_msg};
+  }
+  auto flag_count = static_cast<int>(lua_objlen(lua, -1));
+  std::vector<std::string> flags_content;
+  flags_content.reserve(flag_count);
+  for (int i = 1; i <= flag_count; ++i) {
+    lua_pushnumber(lua, i);
+    lua_gettable(lua, -2);
+    // lua_isstring() is also true for numbers, which would accept { 123 } as the flag "123";
+    // require a real string so such input gets the documented error message.
+    if (lua_type(lua, -1) != LUA_TSTRING) {
+      // Pop the offending element and the flags table.
+      lua_pop(lua, 2);
+      return {Status::NotOK, invalid_flags_msg};
+    }
+    flags_content.emplace_back(lua_tostring(lua, -1));
+    // Pop the current flag.
+    lua_pop(lua, 1);
+  }
+  // Pop the flags table itself.
+  lua_pop(lua, 1);
+
+  return GetFlagsFromStrings(flags_content);
+}
+
+/// RemoveFromRegistry clears the Lua registry entry stored under `name` by assigning nil to it.
+void RemoveFromRegistry(lua_State *lua, const char *name) {
+  lua_pushnil(lua);
+  // registry[name] = nil; lua_setfield pops the nil pushed above.
+  lua_setfield(lua, LUA_REGISTRYINDEX, name);
+}
+
} // namespace lua
diff --git a/src/storage/scripting.h b/src/storage/scripting.h
index 3b2dd45d..6cfa31f0 100644
--- a/src/storage/scripting.h
+++ b/src/storage/scripting.h
@@ -28,19 +28,22 @@
#include "status.h"
inline constexpr const char REDIS_LUA_FUNC_SHA_PREFIX[] = "f_";
+// fmt format string: "{}" is replaced by the script SHA to name the Lua global that caches
+// the flags parsed from an Eval script's shebang line.
+inline constexpr const char REDIS_LUA_FUNC_SHA_FLAGS[] = "f_{}_flags_";
inline constexpr const char REDIS_LUA_REGISTER_FUNC_PREFIX[] = "__redis_registered_";
+// Prefix of the Lua global (suffixed by the function name) holding the flags passed as the
+// third argument of redis.register_function.
+inline constexpr const char REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX[] = "__redis_registered_flags_";
inline constexpr const char REDIS_LUA_SERVER_PTR[] = "__server_ptr";
inline constexpr const char REDIS_FUNCTION_LIBNAME[] = "REDIS_FUNCTION_LIBNAME";
inline constexpr const char REDIS_FUNCTION_NEEDSTORE[] = "REDIS_FUNCTION_NEEDSTORE";
inline constexpr const char REDIS_FUNCTION_LIBRARIES[] = "REDIS_FUNCTION_LIBRARIES";
+// Lua registry key under which the ScriptRunCtx of the currently running script is stored.
+inline constexpr const char REGISTRY_SCRIPT_RUN_CTX_NAME[] = "SCRIPT_RUN_CTX";
namespace lua {
-lua_State *CreateState(Server *srv, bool read_only = false);
+/// CreateState creates a Lua state and stores `srv` in it (global REDIS_LUA_SERVER_PTR).
+/// Whether a script may issue write commands is not fixed per state; it is decided per run
+/// from the kScriptNoWrites bit of the ScriptRunCtx flags.
+lua_State *CreateState(Server *srv);
void DestroyState(lua_State *lua);
Server *GetServer(lua_State *lua);
-void LoadFuncs(lua_State *lua, bool read_only = false);
+/// LoadFuncs installs the `redis` helper table (redis.call, redis.register_function, ...) into `lua`.
+void LoadFuncs(lua_State *lua);
void LoadLibraries(lua_State *lua);
void RemoveUnsupportedFunctions(lua_State *lua);
void EnableGlobalsProtection(lua_State *lua);
@@ -101,4 +104,87 @@ void SHA1Hex(char *digest, const char *script, size_t len);
int RedisMathRandom(lua_State *l);
int RedisMathRandomSeed(lua_State *l);
+/// ScriptFlagType lists the flags an Eval script or Function can declare to turn constraints on or off.
+///
+/// Defaults differ between Eval scripts and Functions. A Function registered without flags gets 0.
+/// An Eval script whose first line is a `#!` shebang is treated as declaring flags (possibly none),
+/// so it gets exactly the flags listed after `flags=`. An Eval script without a shebang line is
+/// implicitly granted kScriptAllowCrossSlotKeys and may access keys in different cluster hash slots;
+/// a script with a shebang may not, unless it declares `allow-cross-slot-keys`.
+///
+/// kScriptAllowOom and kScriptAllowStale are defined, but the flag parser currently rejects
+/// "allow-oom" and "allow-stale" as not supported.
+enum ScriptFlagType : uint64_t {
+ kScriptNoWrites = 1ULL << 0, // "no-writes" flag
+ kScriptAllowOom = 1ULL << 1, // "allow-oom" flag
+ kScriptAllowStale = 1ULL << 2, // "allow-stale" flag
+ kScriptNoCluster = 1ULL << 3, // "no-cluster" flag
+ kScriptAllowCrossSlotKeys = 1ULL << 4, // "allow-cross-slot-keys" flag
+};
+
+/// ScriptFlags is a bit set of ScriptFlagType values combined with bitwise OR,
+/// e.g. `ScriptFlags flags = kScriptNoWrites | kScriptNoCluster;`
+using ScriptFlags = uint64_t;
+
+/// ExtractLibNameFromShebang returns the library name from a Function library shebang such as
+/// `#!lua name=mylib`. It fails if the line is not a `#!lua` shebang, carries an option other than
+/// `name=`, repeats `name=`, omits it, or gives a name that is empty or has characters other than
+/// alphanumerics and underscores.
+[[nodiscard]] StatusOr<std::string> ExtractLibNameFromShebang(std::string_view shebang);
+/// ExtractFlagsFromShebang returns the ScriptFlags declared by the first line of an Eval script.
+/// A `#!lua` shebang yields the flags listed in its optional `flags=a,b,...` option (0 if absent);
+/// a line that does not start with `#!` yields kScriptAllowCrossSlotKeys.
+[[nodiscard]] StatusOr<ScriptFlags> ExtractFlagsFromShebang(std::string_view shebang);
+
+/// GetFlagsFromStrings converts each element of flags_content (the string form of a
+/// ScriptFlagType, e.g. "no-writes") into its bit and combines them with bitwise OR.
+[[nodiscard]] StatusOr<ScriptFlags> GetFlagsFromStrings(const std::vector<std::string> &flags_content);
+
+/// ExtractFlagsFromRegisterFunction extracts the flags passed to redis.register_function.
+///
+/// Note: the top of the Lua stack must be the flags argument of redis.register_function,
+/// a table of strings such as { 'no-writes' }. On success, that table has been popped
+/// from the stack when the function returns.
+[[nodiscard]] StatusOr<ScriptFlags> ExtractFlagsFromRegisterFunction(lua_State *lua);
+
+/// ScriptRunCtx holds the per-invocation state of an Eval script or Function while it runs.
+struct ScriptRunCtx {
+  // Bit set of ScriptFlagType values in effect for this run (same type as uint64_t).
+  ScriptFlags flags = 0;
+  // Slot of the most recently accessed key, or -1 before any key has been accessed.
+  // Later accesses are compared against it to detect cross-slot access within one script.
+  int current_slot = -1;
+};
+
+/// SaveOnRegistry stores `ptr` as light userdata in the Lua registry under `name`;
+/// a null `ptr` stores nil instead, which clears the entry.
+///
+/// Note: light userdata is not owned or tracked by Lua, so the caller must keep `*ptr` alive
+/// for as long as the entry stays in the registry.
+template <typename T>
+void SaveOnRegistry(lua_State *lua, const char *name, T *ptr) {
+  if (ptr == nullptr) {
+    lua_pushnil(lua);
+  } else {
+    lua_pushlightuserdata(lua, ptr);
+  }
+  // registry[name] = <value pushed above>; lua_setfield pops that value.
+  lua_setfield(lua, LUA_REGISTRYINDEX, name);
+}
+
+/// GetFromRegistry returns the pointer stored in the Lua registry under `name`, or nullptr when
+/// the entry is nil. It CHECK-fails if the entry is not light userdata or holds a null pointer.
+/// The Lua stack is left unchanged.
+template <typename T>
+T *GetFromRegistry(lua_State *lua, const char *name) {
+  lua_getfield(lua, LUA_REGISTRYINDEX, name);
+
+  T *ptr = nullptr;
+  if (!lua_isnil(lua, -1)) {
+    // Entries are expected to be written by SaveOnRegistry, i.e. always light user data.
+    CHECK(lua_islightuserdata(lua, -1));
+    ptr = static_cast<T *>(lua_touserdata(lua, -1));
+    CHECK_NOTNULL(ptr);
+  }
+
+  // Pop the value pushed by lua_getfield.
+  lua_pop(lua, 1);
+  return ptr;
+}
+
+/// RemoveFromRegistry clears the Lua registry entry stored under `name` by setting it to nil.
+void RemoveFromRegistry(lua_State *lua, const char *name);
+
} // namespace lua
diff --git a/tests/gocase/unit/scripting/function_test.go
b/tests/gocase/unit/scripting/function_test.go
index d0c12ec1..8d711584 100644
--- a/tests/gocase/unit/scripting/function_test.go
+++ b/tests/gocase/unit/scripting/function_test.go
@@ -22,6 +22,7 @@ package scripting
import (
"context"
_ "embed"
+ "fmt"
"strings"
"testing"
@@ -116,13 +117,13 @@ var testFunctions = func(t *testing.T, enabledRESP3
string) {
t.Run("FUNCTION LOAD errors", func(t *testing.T) {
code := strings.Join(strings.Split(luaMylib1, "\n")[1:], "\n")
- util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD",
code).Err(), ".*Shebang statement.*")
+ require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code).Err(),
"ERR Missing library meta")
code2 := "#!lua\n" + code
- util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD",
code2).Err(), ".*Expect library name.*")
+ require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(),
"ERR Library name was not given")
code2 = "#!lua name=$$$\n" + code
- util.ErrorRegexp(t, rdb.Do(ctx, "FUNCTION", "LOAD",
code2).Err(), ".*valid library name.*")
+ require.Error(t, rdb.Do(ctx, "FUNCTION", "LOAD", code2).Err(),
"ERR Library names can only contain letters, numbers, or underscores(_) and
must be at least one character long")
})
t.Run("FUNCTION LOAD and FCALL mylib1", func(t *testing.T) {
@@ -277,3 +278,269 @@ var testFunctions = func(t *testing.T, enabledRESP3
string) {
}, decodeListLibResult(t, r))
})
}
+
+func TestFunctionScriptFlags(t *testing.T) {
+ srv := util.StartServer(t, map[string]string{})
+ defer srv.Close()
+
+ ctx := context.Background()
+ rdb := srv.NewClient()
+ defer func() { require.NoError(t, rdb.Close()) }()
+
+ t.Run("Function extract-libname-error", func(t *testing.T) {
+ r := rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=mylibname flags=no-writes
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua flags=no-writes name=mylibname
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=mylibname name=mylibname2
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "Redundant library name in script
shebang")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!errorenine name=mylibname
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script
shebang:*")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!luaname=mylibname
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script
shebang:*")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua xxxname=mylibname
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=mylibname key=value
+ redis.register_function('extract_libname_error_func',
function(keys, args) end)`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+ })
+
+ t.Run("Function extract-flags-error", func(t *testing.T) {
+ r := rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=myflags
+ redis.register_function('extract_flags_error_func',
function(keys, args) end, { 'invalid-flag' })`)
+ require.Error(t, r.Err(), "ERR Error while running new function
lib: Unknown flag given: invalid-flag")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=myflags
+ redis.register_function('extract_flags_error_func',
function(keys, args) end, { 'no-writes', 'invalid-flag' })`)
+ require.Error(t, r.Err(), "ERR Error while running new function
lib: Unknown flag given: invalid-flag")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=myflags
+ redis.register_function('extract_flags_error_func',
function(keys, args) end, { {} }`)
+ require.Error(t, r.Err(), "ERR Expects a valid flags argument
to register_function, e.g. flags={ 'no-writes' })")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=myflags
+ redis.register_function('extract_flags_error_func',
function(keys, args) end, { 123 }`)
+ require.Error(t, r.Err(), "ERR Expects a valid flags argument
to register_function, e.g. flags={ 'no-writes' })")
+
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=myflags
+ redis.register_function('extract_flags_error_func',
function(keys, args) end, 'no-writes'`)
+ require.Error(t, r.Err(), "ERR Expects a valid flags argument
to register_function, e.g. flags={ 'no-writes' })")
+ })
+
+ t.Run("no-writes", func(t *testing.T) {
+ r := rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=nowriteslib
+ redis.register_function('default_flag_func', function(keys,
args) return redis.call("set", keys[1], args[1]) end)
+ redis.register_function('no_writes_func', function(keys, args)
return redis.call("set", keys[1], args[1]) end, { 'no-writes' })`)
+ require.NoError(t, r.Err())
+
+ r = rdb.Do(ctx, "FCALL", "default_flag_func", 1, "k1", "v1")
+ require.NoError(t, r.Err())
+ r = rdb.Do(ctx, "FCALL", "no_writes_func", 1, "k2", "v2")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not
allowed from read-only scripts")
+
+ r = rdb.Do(ctx, "FCALL_RO", "default_flag_func", 1, "k1", "v1")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not
allowed from read-only scripts")
+ r = rdb.Do(ctx, "FCALL_RO", "no_writes_func", 1, "k2", "v2")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not
allowed from read-only scripts")
+ })
+
+ srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
+ rdb0 := srv0.NewClient()
+ defer func() { require.NoError(t, rdb0.Close()) }()
+ defer func() { srv0.Close() }()
+ id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00"
+ require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err())
+
+ srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
+ srv1Alive := true
+ defer func() {
+ if srv1Alive {
+ srv1.Close()
+ }
+ }()
+
+ rdb1 := srv1.NewClient()
+ defer func() { require.NoError(t, rdb1.Close()) }()
+ id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01"
+ require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err())
+
+ clusterNodes := fmt.Sprintf("%s %s %d master - 0-10000\n", id0,
srv0.Host(), srv0.Port())
+ clusterNodes += fmt.Sprintf("%s %s %d master - 10001-16383", id1,
srv1.Host(), srv1.Port())
+ require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes,
"1").Err())
+ require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes,
"1").Err())
+
+ // Node0: bar-slot = 5061, test-slot = 6918
+ // Node1: foo-slot = 12182
+ // Different slots of different nodes are not affected by
allow-cross-slot-keys,
+ // and different slots of the same node can be allowed
+ require.NoError(t, rdb0.Set(ctx, "bar", "bar_value", 0).Err())
+ require.NoError(t, rdb0.Set(ctx, "test", "test_value", 0).Err())
+ require.NoError(t, rdb1.Set(ctx, "foo", "foo_value", 0).Err())
+
+ t.Run("no-cluster", func(t *testing.T) {
+ r := rdb0.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=noclusterlib
+ redis.register_function('default_flag_func', function(keys)
return redis.call('get', keys[1]) end)
+ redis.register_function('no_cluster_func', function(keys)
return redis.call('get', keys[1]) end, { 'no-cluster' })`)
+ require.NoError(t, r.Err())
+
+ require.NoError(t, rdb0.Do(ctx, "FCALL", "default_flag_func",
1, "bar").Err())
+
+ r = rdb0.Do(ctx, "FCALL", "no_cluster_func", 1, "bar")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on
cluster, 'no-cluster' flag is set")
+
+ // Only valid in cluster mode
+ require.NoError(t, rdb.Set(ctx, "bar", "rdb_bar_value",
0).Err())
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=noclusterlib
+ redis.register_function('no_cluster_func', function(keys)
return redis.call('get', keys[1]) end, { 'no-cluster' })`)
+ require.NoError(t, r.Err())
+ require.NoError(t, rdb.Do(ctx, "FCALL", "no_cluster_func", 1,
"bar").Err())
+ })
+
+ t.Run("allow-cross-slot-keys", func(t *testing.T) {
+ r := rdb0.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=crossslotlib
+ redis.register_function('default_flag_func_1',
+ function()
+ redis.call('get', 'bar')
+ return redis.call('get', 'test')
+ end
+ )
+
+ redis.register_function('default_flag_func_2',
+ function()
+ redis.call('get', 'bar')
+ return redis.call('get', 'foo')
+ end
+ )
+
+ redis.register_function('default_flag_func_3',
+ function(keys)
+ redis.call('get', keys[1])
+ return redis.call('get', keys[2])
+ end
+ )
+
+ redis.register_function(
+ 'allow_cross_slot_keys_func_1',
+ function()
+ redis.call('get', 'bar')
+ return redis.call('get', 'test')
+ end,
+ { 'allow-cross-slot-keys' })
+
+ redis.register_function(
+ 'allow_cross_slot_keys_func_2',
+ function()
+ redis.call('get', 'bar')
+ return redis.call('get', 'foo')
+ end,
+ { 'allow-cross-slot-keys' })
+
+ redis.register_function(
+ 'allow_cross_slot_keys_func_3',
+ function(keys)
+ redis.call('get', key[1])
+ return redis.call('get', key[2])
+ end,
+ { 'allow-cross-slot-keys' })
+
+ `)
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "FCALL", "default_flag_func_1", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access
keys that do not hash to the same slot")
+
+ r = rdb0.Do(ctx, "FCALL", "default_flag_func_2", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access
a non local key in a cluster node script")
+
+ r = rdb0.Do(ctx, "FCALL", "default_flag_func_3", 2, "bar",
"test")
+ require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access
keys that don't hash to the same slot")
+
+ r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_1", 0)
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_2", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access
a non local key in a cluster node script")
+
+ // Pre-declared keys are not affected by allow-cross-slot-keys
+ r = rdb0.Do(ctx, "FCALL", "allow_cross_slot_keys_func_3", 2,
"bar", "test")
+ require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access
keys that don't hash to the same slot")
+ })
+
+ t.Run("mixed-use", func(t *testing.T) {
+ r := rdb0.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=mixeduselib
+ redis.register_function('no_write_cluster_func_1', function()
redis.call('get', 'bar') end, { 'no-writes', 'no-cluster' })
+
+ redis.register_function('no_write_allow_cross_func_1',
+ function() redis.call('get', 'bar'); return redis.call('get',
'test'); end,
+ { 'no-writes', 'allow-cross-slot-keys' })
+
+ redis.register_function('no_write_allow_cross_func_2',
+ function() redis.call('set', 'bar'); return redis.call('set',
'test'); end,
+ { 'no-writes', 'allow-cross-slot-keys' })
+
+ redis.register_function('no_write_allow_cross_func_3',
+ function() redis.call('get', 'bar'); return redis.call('get',
'foo'); end,
+ { 'no-writes', 'allow-cross-slot-keys' })
+ `)
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "FCALL", "no_write_cluster_func_1", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on
cluster, 'no-cluster' flag is set")
+
+ // no-cluster Only valid in cluster mode
+ r = rdb.Do(ctx, "FUNCTION", "LOAD",
+ `#!lua name=mixeduselib2
+ redis.register_function('no_write_cluster_func_2',
+ function() return redis.call('set', 'bar', 'bar_value') end,
+ { 'no-writes', 'no-cluster' }
+ )
+
+ redis.register_function('no_write_cluster_func_3',
+ function() return redis.call('get', 'bar') end,
+ { 'no-writes', 'no-cluster' }
+ )
+ `)
+ require.NoError(t, r.Err())
+
+ r = rdb.Do(ctx, "FCALL", "no_write_cluster_func_2", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not
allowed from read-only scripts")
+
+ require.NoError(t, rdb.Set(ctx, "bar", "bar_value_rdb",
0).Err())
+ require.NoError(t, rdb.Do(ctx, "FCALL",
"no_write_cluster_func_3", 0).Err())
+
+ require.NoError(t, rdb0.Do(ctx, "FCALL",
"no_write_allow_cross_func_1", 0).Err())
+ r = rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_2", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not
allowed from read-only scripts")
+ r = rdb0.Do(ctx, "FCALL", "no_write_allow_cross_func_3", 0)
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access
a non local key in a cluster node script")
+ })
+}
diff --git a/tests/gocase/unit/scripting/scripting_test.go
b/tests/gocase/unit/scripting/scripting_test.go
index 50680f7d..9e4f2752 100644
--- a/tests/gocase/unit/scripting/scripting_test.go
+++ b/tests/gocase/unit/scripting/scripting_test.go
@@ -480,6 +480,7 @@ math.randomseed(ARGV[1]); return tostring(math.random())
t.Run("EVALSHA_RO - cannot run write commands", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "foo", "bar", 0).Err())
+ // sha1 of `redis.call('del', KEYS[1]);`
r := rdb.Do(ctx, "EVALSHA_RO",
"a1e63e1cd1bd1d5413851949332cfb9da4ee6dc0", "1", "foo")
util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not
allowed from read-only scripts")
})
@@ -607,3 +608,280 @@ func TestScriptingWithRESP3(t *testing.T) {
require.EqualValues(t, []interface{}{"f1", "v1"}, vals)
})
}
+
+// TestEvalScriptFlags checks shebang flag parsing for EVAL/SCRIPT LOAD and the no-writes, no-cluster and allow-cross-slot-keys flags.
+func TestEvalScriptFlags(t *testing.T) {
+ srv := util.StartServer(t, map[string]string{})
+ defer srv.Close()
+
+ ctx := context.Background()
+ rdb := srv.NewClient()
+ defer func() { require.NoError(t, rdb.Close()) }()
+
+ t.Run("Eval extract-flags-error", func(t *testing.T) {
+ r := rdb.Do(ctx, "EVAL",
+ `#!lua name=mylib
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes name=mylib
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua erroroption=no-writes
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=invalid-flag
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,invalid-flag
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes no-cluster
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes flags=no-cluster
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Redundant flags in script shebang")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!errorengine flags=no-writes
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!luaflags=no-writes
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua xxflags=no-writes
+ return 'extract-flags'`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+ })
+
+ t.Run("SCRIPT LOAD extract-flags-error", func(t *testing.T) {
+ r := rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua name=mylib
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua flags=no-writes name=mylib
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua erroroption=no-writes
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua flags=invalid-flag
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua flags=no-writes,invalid-flag
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown flag given:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua flags=no-writes no-cluster
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua flags=no-writes flags=no-cluster
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Redundant flags in script shebang")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!errorengine flags=no-writes
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!luaflags=no-writes
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unexpected engine in script shebang:*")
+
+ r = rdb.Do(ctx, "SCRIPT", "LOAD",
+ `#!lua xxflags=no-writes
+ return 'extract-flags'`)
+ util.ErrorRegexp(t, r.Err(), "ERR Unknown lua shebang option:*")
+ })
+
+ t.Run("no-writes", func(t *testing.T) {
+ r := rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes
+ return redis.call('set', 'k1','v1');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts")
+
+ r = rdb.Do(ctx, "EVAL", `return redis.call('set', 'k2','v2');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua
+ return redis.call('set', 'k3','v3');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb.Do(ctx, "EVAL_RO",
+ `return redis.call('set', 'k4','v4');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts")
+
+ r = rdb.Do(ctx, "EVAL_RO",
+ `#!lua
+ return redis.call('set', 'k5','v5');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts")
+
+ r = rdb.Do(ctx, "EVAL_RO",
+ `#!lua flags=no-writes
+ return redis.call('set', 'k6','v6');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts")
+ })
+
+ srv0 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
+ rdb0 := srv0.NewClient()
+ defer func() { require.NoError(t, rdb0.Close()) }()
+ defer func() { srv0.Close() }()
+ id0 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00"
+ require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODEID", id0).Err())
+
+ srv1 := util.StartServer(t, map[string]string{"cluster-enabled": "yes"})
+ srv1Alive := true
+ defer func() {
+ if srv1Alive {
+ srv1.Close()
+ }
+ }()
+
+ rdb1 := srv1.NewClient()
+ defer func() { require.NoError(t, rdb1.Close()) }()
+ id1 := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01"
+ require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODEID", id1).Err())
+
+ clusterNodes := fmt.Sprintf("%s %s %d master - 0-10000\n", id0, srv0.Host(), srv0.Port())
+ clusterNodes += fmt.Sprintf("%s %s %d master - 10001-16383", id1, srv1.Host(), srv1.Port())
+ require.NoError(t, rdb0.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err())
+ require.NoError(t, rdb1.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err())
+
+ t.Run("no-cluster", func(t *testing.T) {
+ r := rdb0.Do(ctx, "EVAL",
+ `#!lua flags=no-cluster
+ return redis.call('set', 'k','v');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set")
+
+ // Only valid in cluster mode
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-cluster
+ return redis.call('set', 'k','v');`, "0")
+ require.NoError(t, r.Err())
+
+ // Without the no-cluster flag, a script runs in cluster mode whether or not
+ // it starts with a #! shebang line.
+ r = rdb0.Do(ctx, "EVAL", `return redis.call('set', 'k','v');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua
+ return redis.call('set', 'k','v');`, "0")
+ require.NoError(t, r.Err())
+ })
+
+ t.Run("allow-cross-slot-keys", func(t *testing.T) {
+ // Node0: bar-slot = 5061, test-slot = 6918
+ // Node1: foo-slot = 12182
+ // Different slots of different nodes are not affected by allow-cross-slot-keys,
+ // and different slots of the same node can be allowed
+ r := rdb0.Do(ctx, "EVAL",
+ `#!lua flags=allow-cross-slot-keys
+ redis.call('set', 'bar','value_bar');
+ return redis.call('set', 'test', 'value_test');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua flags=allow-cross-slot-keys
+ redis.call('set', 'foo','value_foo');
+ return redis.call('set', 'bar', 'value_bar');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script")
+
+ // With a #!lua shebang but no flags, cross-slot access is rejected
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua
+ redis.call('get', 'bar');
+ return redis.call('get', 'test');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access keys that do not hash to the same slot")
+
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua
+ redis.call('get', 'foo');
+ return redis.call('get', 'bar');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script")
+
+ // Old style (no #!lua shebang, no flags): cross-slot access on the local node is allowed
+ r = rdb0.Do(ctx, "EVAL",
+ `redis.call('get', 'bar');
+ return redis.call('get', 'test');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "EVAL",
+ `redis.call('get', 'foo');
+ return redis.call('get', 'bar');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script")
+
+ // Pre-declared keys are not affected by allow-cross-slot-keys
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua flags=allow-cross-slot-keys
+ redis.call('get', KEYS[1]);
+ return redis.call('get', KEYS[2]);`, "2", "bar", "test")
+ require.EqualError(t, r.Err(), "CROSSSLOT Attempted to access keys that don't hash to the same slot")
+ })
+
+ t.Run("mixed use", func(t *testing.T) {
+ r := rdb0.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,no-cluster
+ return redis.call('get', 'key_a');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Can not run script on cluster, 'no-cluster' flag is set")
+
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,no-cluster
+ return redis.call('set', 'key_a', 'value_a');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts")
+
+ err := rdb.Set(ctx, "key_a", "value_a", 0).Err()
+ require.NoError(t, err)
+ r = rdb.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,no-cluster
+ return redis.call('get', 'key_a');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,allow-cross-slot-keys
+ redis.call('get', 'bar');
+ return redis.call('get', 'test');`, "0")
+ require.NoError(t, r.Err())
+
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,allow-cross-slot-keys
+ redis.call('set', 'bar', 'value_bar');
+ return redis.call('set', 'test', 'value_test');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Write commands are not allowed from read-only scripts")
+
+ r = rdb0.Do(ctx, "EVAL",
+ `#!lua flags=no-writes,allow-cross-slot-keys
+ redis.call('get', 'bar');
+ return redis.call('get', 'foo');`, "0")
+ util.ErrorRegexp(t, r.Err(), "ERR .* Script attempted to access a non local key in a cluster node script")
+ })
+}